[2/N] Core trainer abstraction #158
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

from forge.data_models.loss import LossOutput
from forge.data_models.minibatch import Minibatch


# TODO: This file should NOT be in the data_models folder/package

class Store(ABC):
    """
    Abstract base class for a generic key-value store.

    This class defines the interface for a storage backend that can save and retrieve
    values using string keys. Subclasses should implement the actual storage logic,
    which could be in-memory, on disk, remote (e.g., RDMA, Redis), or any other backend.

    Example use cases include storing model weights, configuration objects, or any
    other data that needs to be accessed by key.

    Methods:
        put(key: str, value: Any) -> None
            Store a value under the specified key.

        get(key: str) -> Any
            Retrieve the value associated with the specified key.

    Subclasses must implement both methods.
    """

    @abstractmethod
    def put(self, key: str, value: Any) -> None:
        """Store a value under a key."""
        pass

    @abstractmethod
    def get(self, key: str) -> Any:
        """Retrieve a value by key."""
        pass

class WeightsBuffer:

Review comment: I don't follow the reason to have this extra layer? Also, a buffer is what holds some individual data, vs this would be the entire store?

Review comment: At a high-level...

""" | ||
Concrete class for managing model weights using a generic key-value Store backend. | ||
This class provides a simple interface to store and retrieve model weights | ||
(or references to them) by delegating the actual storage logic to a Store instance. | ||
The Store abstraction allows for flexible backends (e.g., in-memory, RDMA, file system, torchstore etc.) | ||
without changing the WeightBuffer interface. | ||
Example usage: | ||
store = MyCustomStoreBackend() | ||
buffer = WeightBuffer(store) | ||
buffer.put("model_weights", weights) | ||
latest_weights = buffer.get("model_weights") | ||
Args: | ||
store (Store): An instance of a Store backend to use for storage. | ||
""" | ||
|
||
def __init__(self, store): | ||
""" | ||
Initialize the WeightBuffer with a given Store backend. | ||
Args: | ||
store (Store): The storage backend to use. | ||
""" | ||
self.store = store | ||
|
||
def put(self, key: str, weights): | ||
""" | ||
Store the given weights under the specified key. | ||
Args: | ||
key (str): The key under which to store the weights. | ||
weights: The weights object or reference to store. | ||
""" | ||
self.store.put(key, weights) | ||
|
||
def get(self, key: str): | ||
""" | ||
Retrieve the weights stored under the specified key. | ||
Args: | ||
key (str): The key for which to retrieve the weights. | ||
Returns: | ||
The weights object or reference associated with the key. | ||
""" | ||
return self.store.get(key) | ||
|
||
|
||
class Trainer(ABC):
    """
    Abstract base class for a reinforcement learning (RL) trainer.

    This class defines the interface for any RL trainer implementation.
    It standardizes the methods required for gradient accumulation, applying updates,
    and snapshotting model weights. Subclasses should implement the actual logic
    for these operations, which may vary depending on the underlying model,
    framework, or distributed setup.
    """

    @abstractmethod
    def accumulate_gradients(self, minibatch: Minibatch) -> LossOutput:
        """
        Accumulate gradients for the given minibatch.

        This method is called once per minibatch during training. It should compute
        the gradients for the minibatch and accumulate them (without applying them yet).

        Args:
            minibatch (Minibatch): The minibatch of data to use for gradient computation.

        Returns:
            LossOutput: The computed loss and any additional outputs needed for logging or analysis.
        """
        pass

    @abstractmethod
    def apply_gradients(self) -> None:

Review comment: In general this looks fine, my main concern is whether this would be compatible with Compile and Pipeline parallel APIs? @H-Huang

Review comment: I don't have a concern if we want to expose another

""" | ||
Apply accumulated gradients to the model parameters. | ||
This method should update the model's parameters using the gradients that have | ||
been accumulated so far (e.g., by calling an optimizer step). After this call, | ||
the accumulated gradients should be cleared/reset. | ||
""" | ||
pass | ||
|
||
    @abstractmethod
    def snapshot_weights(self) -> WeightsBuffer:

Review comment: This would likely push weights to the store for checkpoint handling and weight sync to take over.

Review comment: Similar to update_weights in the policy, this will be somewhat dependent on the internal state of apply_gradients: you want to call it right after apply_gradients is done (without awaiting it) and then not call apply_gradients again until it has completed. Not as complex as the policy side, but something to keep in mind.

""" | ||
Save the current model weights and return a buffer handle. | ||
This method should capture the current state of the model's weights and store | ||
them in a WeightBuffer (which may be local or remote, depending on the implementation). | ||
The returned buffer can be used to transfer weights between components or for checkpointing. | ||
Returns: | ||
WeightsBuffer: A handle or reference to the stored weights buffer. | ||
""" | ||
pass |
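As a rough illustration of how these pieces are meant to compose, here is a minimal sketch (not part of this PR): a dict-backed implementation of the Store interface and a driver loop over the Trainer hooks. InMemoryStore and run_update are hypothetical names, and the concrete trainer and minibatch objects are assumed to come from an actual Trainer implementation.

```python
# Hypothetical sketch; not part of this PR. Store/Trainer/WeightsBuffer refer to
# the abstractions defined in the file above.
from typing import Any


class InMemoryStore:
    """Dict-backed example of the Store interface (illustrative only)."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def put(self, key: str, value: Any) -> None:
        self._data[key] = value

    def get(self, key: str) -> Any:
        return self._data[key]


def run_update(trainer, minibatches):
    """Illustrative driver: accumulate over several minibatches, then step once."""
    loss_out = None
    for mb in minibatches:
        # Computes and accumulates gradients; no optimizer step happens here.
        loss_out = trainer.accumulate_gradients(mb)
    # Optimizer step; the accumulated gradients are cleared afterwards.
    trainer.apply_gradients()
    # Publish the updated weights so checkpointing / weight sync can take over.
    weights_buffer = trainer.snapshot_weights()
    return loss_out, weights_buffer
```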
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

import torch
from forge.data_models.prompt import Prompt


@dataclass
class Completion:
    """A model-generated completion for a given prompt."""

    # The original prompt.
    prompt: Prompt

    # The decoded text returned by the model.
    text: str

    # The encoded text (token ids) that were fed into the model.
    prompt_ids: torch.Tensor

    # The encoded text (token ids) that were generated by the model.
    token_ids: torch.Tensor

    # The log probabilities of the target tokens.
    log_probs: Optional[torch.Tensor] = None
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass
class DistributedMetric(ABC):
    """Metrics that are calculated in distributed fashion.

    Metrics computed in each rank are going to be wrapped in DistributedMetric
    according to how they are going to be aggregated. For example, average log prob
    can be wrapped as `Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))` where
    `mask` indicates which token is valid.
    """

    # We need to pass a context argument for distribution setup in the future.
    @abstractmethod
    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        pass

    @abstractmethod
    def local(self) -> torch.Tensor:
        pass


@dataclass
class SumDistributedMetric(DistributedMetric):
    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor

    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        return _try_clone_and_reduce(self.tensor, op=dist.ReduceOp.SUM, group=group)

    def local(self) -> torch.Tensor:
        return self.tensor


@dataclass
class Fraction:
    numerator: DistributedMetric
    denominator: DistributedMetric

    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        return self.numerator.reduce(group) / self.denominator.reduce(group)

    def local(self) -> torch.Tensor:
        return self.numerator.local() / self.denominator.local()


def _try_clone_and_reduce(
    tensor: torch.Tensor, op: dist.ReduceOp, group: dist.ProcessGroup | None
) -> torch.Tensor:
    cloned = tensor.detach().clone()
    if dist.is_initialized():
        dist.all_reduce(cloned, op=op, group=group)
    return cloned
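A small sketch of the pattern the docstring describes: the per-rank mean log-prob expressed as a Fraction of two sum-reduced metrics. The toy tensors below are made up; with torch.distributed uninitialized, reduce() simply returns the local value.

```python
# Toy, single-process illustration; not part of this PR.
import torch

from forge.data_models.distributed_metric import Fraction, SumDistributedMetric

logp = torch.tensor([-0.5, -1.0, -2.0, -0.1])  # per-token log-probs on this rank
mask = torch.tensor([0.0, 1.0, 1.0, 1.0])      # 1.0 marks valid (generated) tokens

avg_logp = Fraction(
    numerator=SumDistributedMetric((logp * mask).sum()),
    denominator=SumDistributedMetric(mask.sum()),
)

print(avg_logp.local())   # mean over this rank's valid tokens: (-1.0 - 2.0 - 0.1) / 3
print(avg_logp.reduce())  # all-reduced mean; equals local() when dist is not initialized
```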
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch
from forge.data_models.scored_completion import ScoredCompletion


@dataclass
class Experience:
    """
    The Experience data class to be used by the trainer.

    Experiences are usually generated from a scored completion by running various
    post-processing steps.
    """

    # Concatenated prompt and sample token ids.
    ids: torch.Tensor

    # The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
    mask: torch.Tensor

    # The weight to apply to the loss of each target token. It's normally computed
    # from the advantage and the reward.
    weights: torch.Tensor

    # The log probabilities of the target tokens; for the prompt part it's set to 0,
    # for the generation part it's computed from the Generator/Sampler.
    log_probs: Optional[torch.Tensor] = None

    # TODO: add more fields as required
    state: str = ""


def from_scored_completion(scored_completion: ScoredCompletion) -> Experience:
    """Converts a ScoredCompletion to an Experience."""
    prompt_ids = scored_completion.completion.prompt_ids
    token_ids = scored_completion.completion.token_ids
    log_probs = scored_completion.completion.log_probs
    ids = torch.cat([prompt_ids, token_ids])
    mask = torch.cat(
        [
            torch.zeros(prompt_ids.shape, dtype=torch.float32),
            torch.ones_like(token_ids, dtype=torch.float32),
        ]
    )
    advantage = scored_completion.score
    weights = mask * advantage
    log_probs = torch.cat(
        [
            torch.zeros(prompt_ids.shape, dtype=torch.float32),
            # TODO: this only works if sample.log_probs is 1
            log_probs,
        ]
    )
    return Experience(ids=ids, mask=mask, weights=weights, log_probs=log_probs)


def from_scored_completions(
    scored_completions: Sequence[ScoredCompletion],
) -> Sequence[Experience]:
    """Converts a sequence of ScoredCompletion to a sequence of Experiences."""
    return [from_scored_completion(sc) for sc in scored_completions]
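For concreteness, here is a toy walk-through (not part of the PR) of the tensors from_scored_completion builds, assuming a 3-token prompt, a 2-token generation, and an advantage/score of 0.5:

```python
# Toy example; not part of this PR. Mirrors the tensor construction above without
# needing a real ScoredCompletion instance.
import torch

prompt_ids = torch.tensor([11, 12, 13])        # completion.prompt_ids
token_ids = torch.tensor([21, 22])             # completion.token_ids
sample_log_probs = torch.tensor([-0.3, -1.1])  # completion.log_probs
advantage = 0.5                                # scored_completion.score

ids = torch.cat([prompt_ids, token_ids])                   # [11, 12, 13, 21, 22]
mask = torch.cat([torch.zeros(3), torch.ones(2)])          # [0., 0., 0., 1., 1.]
weights = mask * advantage                                 # [0., 0., 0., 0.5, 0.5]
log_probs = torch.cat([torch.zeros(3), sample_log_probs])  # [0., 0., 0., -0.3, -1.1]
# These four tensors are what Experience(ids=..., mask=..., weights=..., log_probs=...) receives.
```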
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from forge.data_models.distributed_metric import Fraction
from forge.data_models.minibatch import Minibatch


@dataclass
class LossInput:
    minibatch: Minibatch
    trainer_logits: torch.Tensor


@dataclass
class LossOutput:
    loss: Fraction
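As a hedged sketch of how these types might fit together, a loss function could package a masked-mean negative log-likelihood as a Fraction so it can later be reduced across ranks. compute_nll_loss below is a hypothetical helper, not something defined in this PR.

```python
# Hedged sketch; compute_nll_loss is a hypothetical helper, not part of this PR.
import torch

from forge.data_models.distributed_metric import Fraction, SumDistributedMetric
from forge.data_models.loss import LossOutput


def compute_nll_loss(log_probs: torch.Tensor, mask: torch.Tensor) -> LossOutput:
    """Masked-mean negative log-likelihood packaged for cross-rank reduction."""
    nll_sum = (-log_probs * mask).sum()  # summed NLL over valid tokens on this rank
    token_count = mask.sum()             # number of valid tokens on this rank
    return LossOutput(
        loss=Fraction(
            numerator=SumDistributedMetric(nll_sum),
            denominator=SumDistributedMetric(token_count),
        )
    )
```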
Review comment: I left my comment in 3/N but to repeat here, is it valuable to abstract the buffer too? It's as core to the library as Monarch.

Review comment: The buffer is just a wrapper on top of the store, hence I did not do that. Can you elaborate on your reasoning for abstracting the buffer?
[EDIT]: I don't have an opinion, but it does not hurt to abstract the buffer too.