-
Notifications
You must be signed in to change notification settings - Fork 24
[RFC] Defining core abstractions #149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from abc import ABC, abstractmethod | ||
from collections.abc import Iterator, Sequence | ||
from typing import Dict, Tuple | ||
|
||
import torch | ||
|
||
from forge.data_models.completion import Completion | ||
from forge.data_models.loss import LossOutput | ||
from forge.data_models.minibatch import Minibatch | ||
|
||
from forge.data_models.prompt import Prompt | ||
from forge.data_models.scored_completion import ScoredCompletion | ||
from torch.utils.data import DataLoader, IterableDataset | ||
|
||
|
||
# TODO: This file needs should not be in the data_models folder/package | ||
|
||
|
||
class Trainer(ABC): | ||
@abstractmethod | ||
def accummulate_gradients(self, minibatch: Minibatch) -> LossOutput: | ||
""" | ||
accummulate_gradients is called once per minibatch. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def apply_gradients(self) -> None: | ||
""" | ||
applying accumulated gradients to the model parameters. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def snapshot_weights( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the contract here that the snapshotted weights never change so long as this handle exists? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. For a single RL loop run. So long the learner and policy have the same remote handle reference. then this can be achieved in 3 steps
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, yeah that makes sense. We're calling it |
||
self, | ||
) -> Dict[str, Tuple[torch.Tensor, torch.dtype, torch.Size]]: # TODO: RDMA buffer | ||
""" | ||
applying accumulated gradients to the model parameters. | ||
|
||
the return type is a tuple of weights buffer, dtype, and shape of the original tensor. | ||
""" | ||
# TODO: NEEDS fixing: the weights_handle should be remote handle, like RDMA Buffer handle | ||
pass | ||
Comment on lines
+44
to
+50
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we also apply gradients in snapshot_weights here? Or is it just wrong docstring? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. Wrong doc string, needs fixing. |
||
|
||
|
||
class Generator(ABC): | ||
@abstractmethod | ||
def generate(self, prompt: Prompt, **kwargs) -> list[Completion]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. our (but I see that the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is intentional and the idea is to have the api very generic. Like we can have sglang/HF bases generators and not just vLLM based. |
||
""" | ||
Generate a completion given a prompt. | ||
Args: | ||
prompt: The input prompt. | ||
**kwargs: Additional model-specific generation parameters. | ||
Returns: | ||
str: The generated text. | ||
Comment on lines
+55
to
+62
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this supposed to be one completion or multiple completions? docstring says one completion, return type says list of completions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is list of completions. doc string needs fixing. |
||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def update_weights( | ||
self, weights_handle: dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Our although now that I look at it, we should probably be passing the version explicitly rather than tracking implicitly in the policy cc @joecummings @pbontrager @pradeepfn There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right. I have a TODO in the doc string, this needs to be a handle. Also this needs to be abstracted to. Like we need to have a weights buffer abstractions and the implementation can figureout if it's a handle, raw/bytes etc. Will update. |
||
): | ||
""" | ||
Update the weights of the model. | ||
Args: | ||
weights: A dictionary of weights to update. | ||
""" | ||
# TODO: NEEDS fixing: the weights_handle should be remote handle, like RDMA Buffer handle | ||
pass | ||
|
||
|
||
class Scorer(ABC): | ||
@abstractmethod | ||
def score(self, completion: Completion) -> ScoredCompletion: | ||
pass | ||
|
||
def score_batch( | ||
self, completions: Sequence[Completion] | ||
) -> Sequence[ScoredCompletion]: | ||
""" | ||
Optionally override for efficient batch scoring. | ||
""" | ||
return [self.score(c) for c in completions] | ||
|
||
|
||
class PromptDataset(IterableDataset): | ||
""" | ||
Users should inherit from this and implement __iter__. | ||
""" | ||
|
||
def __iter__(self) -> Iterator[Prompt]: | ||
""" | ||
defines how to generate or yield SimpleElement objects. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
class PromptDataLoader(DataLoader): | ||
""" | ||
subclass of DataLoader to handles batching, parallelism, and other data serving concerns. | ||
""" | ||
|
||
def __init__(self, dataset: PromptDataset, **kwargs): | ||
super().__init__(dataset, **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
import torch | ||
from forge.data_models.prompt import Prompt | ||
|
||
|
||
@dataclass | ||
class Completion: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. similar to
|
||
"""A model-generated completion for a given prompt.""" | ||
|
||
# The original prompt. | ||
prompt: Prompt | ||
|
||
# the decoded text returned by the model | ||
text: str | ||
|
||
# the encoded text (token ids) that were fed into the model | ||
prompt_ids: torch.Tensor | ||
|
||
# the encoded text (token ids) that were generated by the model | ||
token_ids: torch.Tensor | ||
|
||
# the log probabilities of the target tokens | ||
log_probs: Optional[torch.Tensor] = None |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
|
||
import torch | ||
import torch.distributed as dist | ||
|
||
|
||
@dataclass | ||
class DistributedMetric(ABC): | ||
"""Metrics that are calculated in distributed fashion. | ||
|
||
Metrics computed in each rank are going to be wrapped in DistributedMetric | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I get the point but I am confused about the api. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have something in my private branch which provides a motivation for how this will be used. https://github.com/meta-pytorch/forge/blob/rithesh/reinforce/src/forge/trainers/huggingface_trainer.py#L61-L119 Happy to chat further. |
||
according to how they are going to be aggregated. For example, average log prob | ||
can be wrapped as `Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))` where | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Is this how it's represented? I get the intention but I kind of hate the math DSL it implies lol I wonder if there's any way we can use DTensor for this more elegantly, cc @LucasLLC There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I already did some work on metric logging aggregation. Its well tested. I think we should check if we can/should use it here: https://github.com/meta-pytorch/forge/tree/main/src/forge/data/dataset_metrics#2-metricsaggregator When i did it, it was heavily focused on dataset, but it doesnt have to be. It should work for observability, rewards, etc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have something in my private branch which provides a motivation for how this will be used. https://github.com/meta-pytorch/forge/blob/rithesh/reinforce/src/forge/trainers/huggingface_trainer.py#L61-L119 This is specifically needed if you are using the computational efficient accummulate and apply gradients approach rather than the single train step for every batch. Happy to chat further. |
||
`mask` indicates which token is valid. | ||
""" | ||
|
||
# We need to pass a context argument for distribution setup in the future. | ||
@abstractmethod | ||
def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor: | ||
pass | ||
|
||
@abstractmethod | ||
def local(self) -> torch.Tensor: | ||
pass | ||
|
||
|
||
@dataclass | ||
class SumDistributedMetric(DistributedMetric): | ||
def __init__(self, tensor: torch.Tensor) -> None: | ||
self.tensor = tensor | ||
|
||
def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor: | ||
return _try_clone_and_reduce(self.tensor, op=dist.ReduceOp.SUM, group=group) | ||
|
||
def local(self) -> torch.Tensor: | ||
return self.tensor | ||
|
||
|
||
@dataclass | ||
class Fraction: | ||
numerator: DistributedMetric | ||
denominator: DistributedMetric | ||
|
||
def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor: | ||
return self.numerator.reduce(group) / self.denominator.reduce(group) | ||
|
||
def local(self) -> torch.Tensor: | ||
return self.numerator.local() / self.denominator.local() | ||
|
||
|
||
def _try_clone_and_reduce( | ||
tensor: torch.Tensor, op: dist.ReduceOp, group: dist.ProcessGroup | None | ||
) -> torch.Tensor: | ||
cloned = tensor.detach().clone() | ||
if dist.is_initialized(): | ||
dist.all_reduce(cloned, op=op, group=group) | ||
return cloned |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from dataclasses import dataclass | ||
from typing import Optional, Sequence | ||
|
||
import torch | ||
from forge.data_models.scored_completion import ScoredCompletion | ||
|
||
|
||
@dataclass | ||
class Experience: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's say we have a multi-turn conversation. Does that correspond to one Experience, or multiple Experiences? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For example, are we supposed to have the following mask? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Experience class is generic and is agnostic to single-turn or multi-turn structure. It is up to your data processing pipeline to decide how to chunk conversations into Experience instances. If you want to train on multi-turn conversations, you would typically:
OR If you want to treat each turn as a separate training example, you could split the conversation into multiple Experience instances, one per turn. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. overall I like this representation. It seems like the right primitive to pass to the trainer, avoiding the trainer having to care about implementation details on getting to the right format. Also like the idea of
cc @joecummings |
||
""" | ||
The Experience data class to be used by the trainer. | ||
|
||
Experiences are usually generated from a scored completion and running various post processing steps. | ||
""" | ||
|
||
# Concatenated prompt and sample token ids. | ||
ids: torch.Tensor | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we currently do this concat in the |
||
|
||
# The mask for the target ids, 0 for prompt tokens, 1 for sample tokens. | ||
mask: torch.Tensor | ||
|
||
# The weight to apply to the loss of each target token. It's normally computed | ||
# from the advantage and the reward. | ||
weights: torch.Tensor | ||
|
||
# The log probabilities of the target tokens, for prompt part it's set to 0, | ||
# for generation part it's computed from the Generator/Sampler. | ||
log_probs: Optional[torch.Tensor] = None | ||
|
||
# TODO: add more fields as required | ||
state: str = "" | ||
|
||
|
||
def from_scored_completion(scored_completion: ScoredCompletion) -> Experience: | ||
"""Converts a ScoredCompletion to an Experience.""" | ||
prompt_ids = scored_completion.completion.prompt_ids | ||
token_ids = scored_completion.completion.token_ids | ||
log_probs = scored_completion.completion.log_probs | ||
ids = torch.cat([prompt_ids, token_ids]) | ||
mask = torch.cat( | ||
[ | ||
torch.zeros(prompt_ids.shape, dtype=torch.float32), | ||
torch.ones_like(token_ids, dtype=torch.float32), | ||
] | ||
) | ||
advantage = scored_completion.score | ||
weights = mask * advantage | ||
log_probs = torch.cat( | ||
[ | ||
torch.zeros(prompt_ids.shape, dtype=torch.float32), | ||
# TODO: this only works if sample.log_probs is 1 | ||
log_probs, | ||
] | ||
) | ||
return Experience(ids=ids, mask=mask, weights=weights, log_probs=log_probs) | ||
|
||
|
||
def from_scored_completions( | ||
scored_completions: Sequence[ScoredCompletion], | ||
) -> Sequence[Experience]: | ||
"""Converts a sequence of ScoredCompletion to a sequence of Experiences.""" | ||
return [from_scored_completion(sc) for sc in scored_completions] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from dataclasses import dataclass | ||
|
||
import torch | ||
from forge.data_models.distributed_metric import Fraction | ||
from forge.data_models.minibatch import Minibatch | ||
|
||
|
||
@dataclass | ||
class LossInput: | ||
minibatch: Minibatch | ||
trainer_logits: torch.Tensor | ||
|
||
|
||
@dataclass | ||
class LossOutput: | ||
loss: Fraction |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
from typing import Any, Sequence | ||
|
||
import torch | ||
from forge.data_models.experience import Experience | ||
|
||
|
||
@dataclass | ||
class Minibatch: | ||
"""The minibatch that trainer will recieve.""" | ||
|
||
# The input sequence token ids for the trainer forward pass. | ||
input_ids: torch.Tensor | ||
|
||
# The segment ids for the input sequence token ids. Same segment | ||
# ids respresent the same sequence. | ||
segment_ids: torch.Tensor | ||
|
||
# The targets required for loss computation, usually concatenated prompt and | ||
# sample token ids. | ||
target_ids: torch.Tensor | ||
|
||
# The mask for the target ids, 0 for prompt tokens, 1 for sample tokens. | ||
target_mask: torch.Tensor | ||
|
||
# The weight to apply to the loss of each target token. It's normally computed | ||
# from the advantage and the reward. | ||
target_weights: torch.Tensor | ||
|
||
# The log probabilities of the target tokens, for prompt part it's set to 0, | ||
# for generation part it's computed from the sampler. | ||
target_log_probs: torch.Tensor | ||
|
||
|
||
def from_experiences( | ||
exps: Sequence[Experience], max_seq_len: int, pad_val: int = 0 | ||
) -> Minibatch: | ||
""" | ||
Convert a list of experiences to a minibatch. | ||
""" | ||
|
||
def pack_sequence( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks Felipe, I will take a look to see if I can use packed dataset instead? |
||
tensors: Sequence[torch.Tensor], | ||
pad_val: Any, | ||
dtype: torch.dtype, | ||
max_len: int, | ||
) -> torch.Tensor: | ||
"""Packs multiple tensors along the seq dim.""" | ||
seq = torch.cat(tensors) | ||
pad_len = max_len - seq.size(0) | ||
if pad_len < 0: | ||
raise ValueError( | ||
f"Sequence lenth {seq.size(0)} exceeds the maximum length {max_len}" | ||
) | ||
return torch.nn.functional.pad(seq, (0, pad_len), value=pad_val)[None, ...].to( | ||
dtype | ||
) | ||
|
||
mini_batch = {} | ||
exp_list = defaultdict(list) | ||
for i, exp in enumerate(exps): | ||
input_ids = exp.ids[:-1] | ||
exp_list["input_ids"].append(input_ids) | ||
exp_list["target_ids"].append(exp.ids[1:]) | ||
exp_list["segment_ids"].append(torch.ones_like(input_ids) * i) | ||
exp_list["target_mask"].append(exp.mask[1:]) | ||
exp_list["target_weights"].append(exp.weights[1:]) | ||
exp_list["target_log_probs"].append(exp.log_probs[1:]) | ||
|
||
for k, v in exp_list.items(): | ||
_dtype = torch.int64 | ||
if k == "target_mask" or k == "target_weights" or k == "target_log_probs": | ||
_dtype = torch.float32 | ||
|
||
mini_batch[k] = pack_sequence(v, pad_val, _dtype, max_seq_len) | ||
|
||
return Minibatch(**mini_batch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why are
accumulate/apply
kept separate?aside, we should be able to implement this in our Titan trainer by pulling out the
loss.backwards
into its own functioncc @wwwjn @felipemello1
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mainly computational efficiency.
This method would help reduces the frequency of expensive parameter updates
and potentially enables better utilization of hardware (Can batch operations like all-reduce in distributed scenarios). Essentially for every batch/minibatch you can don't need to update the params.. you can just accumulate them and then apply once a suitable number of batches has been trained.
Something like below...
We can expose another
step
API in trainer... which would be a combination of accumulate and apply.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it, that makes sense to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok our titan integration doesn't yet have grad accumulation: #146 - it seems possible to add this API surface along with the high level
step