[RFC] Defining core abstractions #149
Conversation
Summary: Defining core abstractions
Test Plan: Unit tests in the next PR once we have the stubs implemented
Convert a list of experiences to a minibatch.
"""

def pack_sequence(
Thanks Felipe, I will take a look to see if I can use a packed dataset instead.
""" | ||
applying accumulated gradients to the model parameters. | ||
|
||
the return type is a tuple of weights buffer, dtype, and shape of the original tensor. | ||
""" | ||
# TODO: NEEDS fixing: the weights_handle should be remote handle, like RDMA Buffer handle | ||
pass |
Why do we also apply gradients in snapshot_weights here? Or is it just a wrong docstring?
Good catch. Wrong docstring, needs fixing.
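For reference, the corrected docstring might read something like this (a sketch; the exact wording and the dict return shape are inferred from this thread, not taken from the PR):

@abstractmethod
def snapshot_weights(
    self,
) -> dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]:
    """
    Snapshot the current model weights without applying any gradients.

    Returns:
        A mapping from parameter name to a tuple of
        (weights buffer, dtype, shape of the original tensor).
    """
    # TODO: the weights buffer should be a remote handle, like an RDMA buffer handle.
    pass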
def generate(self, prompt: Prompt, **kwargs) -> list[Completion]:
    """
    Generate a completion given a prompt.

    Args:
        prompt: The input prompt.
        **kwargs: Additional model-specific generation parameters.

    Returns:
        str: The generated text.
Is this supposed to be one completion or multiple completions? The docstring says one completion; the return type says a list of completions.
I think a list of completions makes more sense.
It is a list of completions; the docstring needs fixing.
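A possible corrected docstring (sketch only, not the wording that eventually landed):

def generate(self, prompt: Prompt, **kwargs) -> list[Completion]:
    """
    Generate completions for a prompt.

    Args:
        prompt: The input prompt.
        **kwargs: Additional model-specific generation parameters.

    Returns:
        list[Completion]: The generated completions, one per sample.
    """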
class DistributedMetric(ABC):
    """Metrics that are calculated in distributed fashion.

    Metrics computed in each rank are going to be wrapped in DistributedMetric
I get the point, but I am confused about the API.
Can you give a more detailed example here?
I have something in my private branch which motivates how this will be used: https://github.com/meta-pytorch/forge/blob/rithesh/reinforce/src/forge/trainers/huggingface_trainer.py#L61-L119
Happy to chat further.
@dataclass
class Experience:
Let's say we have a multi-turn conversation. Does that correspond to one Experience, or multiple Experiences?
For example, are we supposed to have the following mask?
[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1]
The Experience class is generic and agnostic to single-turn or multi-turn structure. It is up to your data processing pipeline to decide how to chunk conversations into Experience instances.
If you want to train on multi-turn conversations, you would typically (see the sketch below):
- Concatenate all turns into a single sequence of token ids.
- Set the mask and weights appropriately.
- Store this as a single Experience.
Alternatively, if you want to treat each turn as a separate training example, you could split the conversation into multiple Experience instances, one per turn.
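For illustration, a minimal sketch of the single-Experience option. The mask/weights field names and the import path are assumptions based on this thread, not the actual class definition:

import torch

from forge.data_models.experience import Experience  # path assumed

# Token ids for a two-turn conversation; user turns are context only,
# assistant turns are the tokens we train on.
user_1, asst_1 = [11, 12, 13, 14], [21, 22, 23, 24]
user_2, asst_2 = [15, 16, 17, 18], [25, 26, 27]

ids = torch.tensor(user_1 + asst_1 + user_2 + asst_2)
# 1 on assistant tokens, 0 on user tokens, matching the mask shape asked about above.
mask = torch.tensor(
    [0] * len(user_1) + [1] * len(asst_1) + [0] * len(user_2) + [1] * len(asst_2)
)

exp = Experience(ids=ids, mask=mask, weights=mask.float())  # one Experience per conversation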
@dataclass
class Prompt:
I don't know about the name Prompt.
I am not picky about the names. The RL world has tons of jargon. Happy to pick one based on what the team agrees on. (This applies to everything in this diff.)
from forge.data_models.prompt import Prompt


class VLLMGenerator(Generator):
This and the Generator in api.py are supposed to be Monarch Actors, no?
Nope. The intention is to distinguish between the implementation and the actors; the actors would be a very skinny version that injects a specific generator/trainer. Keeping the actor as skinny as possible allows us to inject any generator implementation. For example:
class Policy(Actor):
    def __init__(self, generator: Generator):
        super().__init__()
        rank = current_rank()
        self.rank = rank.rank
        self.local_rank = rank["gpus"]
        self.world_size = rank.extent.nelements
        self._set_env_vars()
        self.generator = generator

    @endpoint
    async def generate(self, prompt: Prompt) -> list[Completion]:
        return self.generator.generate(prompt)

    @endpoint
    async def update_weights(
        self, weights_buffer: dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]
    ):
        return self.generator.update_weights(weights_buffer)

    def _set_env_vars(self):
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["RANK"] = str(self.rank)
        os.environ["LOCAL_RANK"] = str(self.local_rank)
Thanks @Ritesh1905 for putting this together, and for your patience! I took a look through and pointed out what was different from what we did, tagging the people I think are most relevant.
My overall 2c:
- I like the data abstractions. As it is, Forge is sort of the glue between Titan and vLLM. What I don't love is that some of the abstractions leak (e.g. we have to convert a vLLM output into an Experience, and in the train loop we have to do even more conversions).
- Our actors are almost the same, except for the trainer. I would like to better understand the differences. Overall, I think train_step is probably not enough functionality to expose to the end user, but I don't have enough context to understand how it's organized in TBR?
@abstractmethod
def update_weights(
    self, weights_handle: dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]
Our update_weights implementation is very similar, except the weights handle is essentially a handle given and tracked by torchstore.
Although, now that I look at it, we should probably be passing the version explicitly rather than tracking it implicitly in the policy. cc @joecummings @pbontrager @pradeepfn
Right. I have a TODO in the docstring; this needs to be a handle.
This also needs to be abstracted: we should have a weights buffer abstraction, and the implementation can figure out whether it's a handle, raw bytes, etc.
Will update.
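For illustration, one hypothetical shape for that abstraction (all names below are invented for this sketch; nothing here is in the PR):

class WeightsBuffer(ABC):
    """Opaque reference to a weights snapshot; the implementation decides
    whether it is a remote (e.g. RDMA) handle, raw bytes, a store key, etc."""

    @abstractmethod
    def materialize(self) -> dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]:
        """Resolve the reference into local tensors."""
        ...

update_weights would then take a WeightsBuffer instead of the raw dict:

@abstractmethod
def update_weights(self, weights: WeightsBuffer) -> None:
    ...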
class Trainer(ABC):
    @abstractmethod
    def accummulate_gradients(self, minibatch: Minibatch) -> LossOutput:
Why are accumulate/apply kept separate?
As an aside, we should be able to implement this in our Titan trainer by pulling the loss.backward() out into its own function.
Mainly computational efficiency.
This method helps reduce the frequency of expensive parameter updates and potentially enables better utilization of hardware (it can batch operations like all-reduce in distributed scenarios). Essentially, you don't need to update the params for every batch/minibatch; you can just accumulate the gradients and then apply them once a suitable number of batches has been trained.
Something like below:

for step in range(20):
    # accumulate gradients for N mini-batches
    for mini_batch in range(10):
        trainer.accummulate_gradients(mini_batch)
    trainer.apply_gradients()
We can expose another step API in the trainer, which would be a combination of accumulate and apply.
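A minimal sketch of that step API, assuming an apply_gradients counterpart to accummulate_gradients (the method spelling follows the diff):

class Trainer(ABC):
    ...

    def step(self, minibatches: list[Minibatch]) -> list[LossOutput]:
        # Accumulate over all minibatches, then apply the update once.
        losses = [self.accummulate_gradients(mb) for mb in minibatches]
        self.apply_gradients()
        return losses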
got it, that makes sense to me
OK, our Titan integration doesn't yet have grad accumulation: #146. It seems possible to add this API surface along with the high-level step.
pass

@abstractmethod
def snapshot_weights(
Is the contract here that the snapshotted weights never change so long as this handle exists?
Yes, for a single RL loop run. As long as the learner and policy have the same remote handle reference, this can be achieved in 3 steps:
- At init, the controller provides the same remote handle to the learner and policy.
- When it's time to update weights, the controller requests the trainer to push weights to the remote buffer.
- The controller requests the policy to read the buffer and update weights.
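Roughly, in pseudocode (the controller-side names here are hypothetical, not from the PR):

# 1. At init, the controller hands the same remote handle to both actors.
handle = allocate_remote_buffer()               # hypothetical
trainer = spawn_trainer(weights_handle=handle)  # hypothetical wiring
policy = spawn_policy(weights_handle=handle)

# Later, each time weights should sync:
trainer.snapshot_weights()     # 2. trainer pushes weights into the shared buffer
policy.update_weights(handle)  # 3. policy reads the buffer and swaps weights in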
OK, yeah, that makes sense. We're calling it push_weights now, but snapshot may be more accurate.
class Generator(ABC):
    @abstractmethod
    def generate(self, prompt: Prompt, **kwargs) -> list[Completion]:
Our generate returns the RequestOutput in order to access the tokenized input IDs. cc @pbontrager @Jack-Khuu
(But I see that the Completion is not the vLLM CompletionOutput.)
(but I see that the Completion is not the vLLM CompletionOutput)

This is intentional; the idea is to keep the API very generic, so that we can have sglang/HF-based generators and not just vLLM-based ones.
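For illustration, a second backend against the same ABC might look like this (a sketch; the constructor arguments and the conversion step are assumptions):

class HFGenerator(Generator):
    """HuggingFace-backed Generator, showing the ABC is backend-agnostic."""

    def __init__(self, model, tokenizer):
        self.model = model          # e.g. an AutoModelForCausalLM
        self.tokenizer = tokenizer

    def generate(self, prompt: Prompt, **kwargs) -> list[Completion]:
        # Tokenize the prompt, sample from the HF model, then convert each
        # sample into the backend-neutral Completion (conversion omitted here).
        ...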
""" | ||
|
||
# Concatenated prompt and sample token ids. | ||
ids: torch.Tensor |
We currently do this concat in the train_step, which I don't love.
@dataclass
class Experience:
Overall I like this representation. It seems like the right primitive to pass to the trainer, avoiding the trainer having to care about the implementation details of getting into the right format.
I also like the idea of from(...), but my stylistic preference is:

@dataclass
class Experience:
    @classmethod
    def from_scored_completions(cls, ...) -> "Experience":
        ...

cc @joecummings
Metrics computed in each rank are going to be wrapped in DistributedMetric
according to how they are going to be aggregated. For example, average log prob
can be wrapped as `Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))` where
Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))

Is this how it's represented? I get the intention, but I kind of hate the math DSL it implies, lol.
I wonder if there's any way we can use DTensor for this more elegantly. cc @LucasLLC
I already did some work on metric logging aggregation. It's well tested. I think we should check whether we can/should use it here: https://github.com/meta-pytorch/forge/tree/main/src/forge/data/dataset_metrics#2-metricsaggregator
When I did it, it was heavily focused on datasets, but it doesn't have to be. It should work for observability, rewards, etc.
I have something in my private branch which motivates how this will be used: https://github.com/meta-pytorch/forge/blob/rithesh/reinforce/src/forge/trainers/huggingface_trainer.py#L61-L119
This is specifically needed if you are using the computationally efficient accumulate-and-apply gradients approach rather than a single train step for every batch. Happy to chat further.
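To make the DSL concrete, a minimal sketch of how Sum and Fraction could compose. The class shapes and the reduce() name are assumptions based on the docstring above, not the actual implementation:

import torch
import torch.distributed as dist


class Sum(DistributedMetric):
    """Per-rank values that are combined by summation."""

    def __init__(self, value: torch.Tensor):
        self.value = value

    def reduce(self) -> torch.Tensor:
        dist.all_reduce(self.value)  # default op is SUM
        return self.value


class Fraction(DistributedMetric):
    """Numerator and denominator reduced independently, then divided."""

    def __init__(self, numerator: "Sum", denominator: "Sum"):
        self.numerator = numerator
        self.denominator = denominator

    def reduce(self) -> torch.Tensor:
        return self.numerator.reduce() / self.denominator.reduce()


# logp, mask: per-rank tensors of token log-probs and the loss mask.
# Average log prob over masked tokens, aggregated correctly across ranks:
avg_logp = Fraction(Sum((logp * mask).sum()), Sum(mask.sum())).reduce()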
@dataclass
class Completion:
Similar to Experience, I think we should incorporate atomic data representations into Forge with implemented converters like:

@dataclass
class Completion:
    @classmethod
    def from_vllm_outputs(cls, request_output: RequestOutput) -> "Completion":
        ...
A different version of this has landed in main.