111 changes: 111 additions & 0 deletions src/forge/data_models/api.py
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from typing import Dict, Tuple

import torch

from forge.data_models.completion import Completion
from forge.data_models.loss import LossOutput
from forge.data_models.minibatch import Minibatch

from forge.data_models.prompt import Prompt
from forge.data_models.scored_completion import ScoredCompletion
from torch.utils.data import DataLoader, IterableDataset


# TODO: This file should not be in the data_models folder/package


class Trainer(ABC):
    @abstractmethod
    def accummulate_gradients(self, minibatch: Minibatch) -> LossOutput:
Contributor

why are accumulate/apply kept separate?

aside, we should be able to implement this in our Titan trainer by pulling the loss.backward() call out into its own function

cc @wwwjn @felipemello1

Contributor Author

@Ritesh1905 Ritesh1905 Sep 15, 2025

Mainly computational efficiency.
Keeping them separate reduces the frequency of expensive parameter updates
and potentially enables better utilization of hardware (e.g. operations like all-reduce can be batched in distributed scenarios). Essentially, you don't need to update the params for every batch/minibatch; you can accumulate gradients and apply them once a suitable number of batches has been processed.

Something like below...

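for step in range(20):
    # accumulate gradients over N (here 10) minibatches
    for _ in range(10):
        mini_batch = next(minibatches)  # hypothetical iterator of Minibatch objects
        trainer.accummulate_gradients(mini_batch)
    # apply the accumulated gradients once per step
    trainer.apply_gradients()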

We can expose another step API in trainer... which would be a combination of accumulate and apply.
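
The accumulate-then-apply pattern described above can be sketched with a toy scalar model. This is only an illustration, not the PR's Trainer: `ToyTrainer`, its learning rate, the squared-error loss, and the `num_updates` counter are assumptions made for the example; only the two method names come from the API under review.

```python
class ToyTrainer:
    """Hypothetical stand-in for a Trainer: fits y = w * x with squared error."""

    def __init__(self, lr: float = 0.1) -> None:
        self.w = 0.0
        self.lr = lr
        self._grad_sum = 0.0
        self._num_accumulated = 0
        self.num_updates = 0

    def accummulate_gradients(self, minibatch: list[tuple[float, float]]) -> float:
        """Add this minibatch's mean gradient to the running sum; no update yet."""
        grad = sum(2 * (self.w * x - y) * x for x, y in minibatch) / len(minibatch)
        loss = sum((self.w * x - y) ** 2 for x, y in minibatch) / len(minibatch)
        self._grad_sum += grad
        self._num_accumulated += 1
        return loss

    def apply_gradients(self) -> None:
        """Apply the averaged accumulated gradient once, then reset the state."""
        self.w -= self.lr * self._grad_sum / self._num_accumulated
        self._grad_sum = 0.0
        self._num_accumulated = 0
        self.num_updates += 1


trainer = ToyTrainer()
data = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (0.5, 1.0)]]  # targets follow y = 2x
for step in range(20):
    for minibatch in data:
        trainer.accummulate_gradients(minibatch)
    trainer.apply_gradients()
```

Here the parameter is updated 20 times even though 40 minibatches were processed, which is the saving the comment refers to.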

Contributor

got it, that makes sense to me

Contributor

ok our titan integration doesn't yet have grad accumulation: #146 - it seems possible to add this API surface along with the high level step

        """
        accummulate_gradients is called once per minibatch.
        """
        pass

    @abstractmethod
    def apply_gradients(self) -> None:
        """
        applying accumulated gradients to the model parameters.
        """
        pass

    @abstractmethod
    def snapshot_weights(
Contributor

Is the contract here that the snapshotted weights never change so long as this handle exists?

Contributor Author

Yes, for a single RL loop run. As long as the learner and policy hold the same remote handle reference, this can be achieved in 3 steps:

  1. At the init, controller provides the same remote handle to learner and policy
  2. When it's time to update weights, controller requests trainer to push weights to remote buffer.
  3. Controller requests policy to read the buffer and update weights.
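
The three steps above can be sketched in-process as follows. Everything here is hypothetical scaffolding: `SharedWeightsBuffer` stands in for a remote (e.g. RDMA) handle, and `Learner`, `Policy`, and their methods are illustrative names, not the PR's classes.

```python
from dataclasses import dataclass, field


@dataclass
class SharedWeightsBuffer:
    """Hypothetical in-process stand-in for a remote buffer handle."""

    weights: dict[str, list[float]] = field(default_factory=dict)


class Learner:
    def __init__(self, handle: SharedWeightsBuffer) -> None:
        self.handle = handle
        self.params = {"layer.weight": [0.0, 0.0]}

    def train_step(self) -> None:
        # Placeholder for real training: nudge every parameter by 1.
        self.params = {k: [v + 1.0 for v in vs] for k, vs in self.params.items()}

    def push_weights(self) -> None:
        # Copy, so later training does not mutate what the policy will read.
        self.handle.weights = {k: list(vs) for k, vs in self.params.items()}


class Policy:
    def __init__(self, handle: SharedWeightsBuffer) -> None:
        self.handle = handle
        self.params: dict[str, list[float]] = {}

    def update_weights(self) -> None:
        self.params = {k: list(vs) for k, vs in self.handle.weights.items()}


# Step 1: the controller hands the same handle to both sides.
handle = SharedWeightsBuffer()
learner, policy = Learner(handle), Policy(handle)

learner.train_step()
learner.push_weights()   # Step 2: the trainer pushes weights to the buffer.
learner.train_step()     # Further training leaves the pushed snapshot untouched.
policy.update_weights()  # Step 3: the policy reads the buffer.
```

The copy in `push_weights` is what makes the pushed weights a snapshot: the learner keeps training while the policy still reads the values that were pushed.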

Contributor

ok, yeah that makes sense. We're calling it push_weights now but snapshot may be more accurate

        self,
    ) -> Dict[str, Tuple[torch.Tensor, torch.dtype, torch.Size]]:  # TODO: RDMA buffer
        """
        applying accumulated gradients to the model parameters.

        the return type is a tuple of weights buffer, dtype, and shape of the original tensor.
        """
        # TODO: NEEDS fixing: the weights_handle should be remote handle, like RDMA Buffer handle
        pass
Comment on lines +44 to +50
Contributor

why do we also apply gradients in snapshot_weights here? Or is it just wrong docstring?

Contributor Author

Good catch. Wrong doc string, needs fixing.



class Generator(ABC):
    @abstractmethod
    def generate(self, prompt: Prompt, **kwargs) -> list[Completion]:
Contributor

our generate returns the RequestOutput in order to access the tokenized input IDs cc @pbontrager @Jack-Khuu

(but I see that the Completion is not the vLLM CompletionOutput)

Contributor Author

> (but I see that the Completion is not the vLLM CompletionOutput)

This is intentional: the idea is to keep the API generic, so that we can have SGLang- or HF-based generators and not just vLLM-based ones.
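
A minimal sketch of what a backend-agnostic generator interface could look like. `SimpleCompletion`, `TextGenerator`, `EchoGenerator`, and the `n` keyword are all hypothetical names chosen for the example; they are trimmed-down stand-ins, not forge's `Completion` or `Generator`.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class SimpleCompletion:
    """Hypothetical, trimmed-down stand-in for forge's Completion."""

    prompt: str
    text: str


class TextGenerator(ABC):
    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> list[SimpleCompletion]:
        """Return one or more completions for the prompt."""


class EchoGenerator(TextGenerator):
    """Toy backend; a vLLM, SGLang, or HF backend would wrap its own engine here."""

    def generate(self, prompt: str, **kwargs) -> list[SimpleCompletion]:
        n = kwargs.get("n", 1)  # hypothetical sampling parameter
        return [SimpleCompletion(prompt, f"{prompt} [sample {i}]") for i in range(n)]


completions = EchoGenerator().generate("hello", n=2)
```

Callers depend only on the returned completion type, so swapping the backend does not change downstream code.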

        """
        Generate a completion given a prompt.
        Args:
            prompt: The input prompt.
            **kwargs: Additional model-specific generation parameters.
        Returns:
            str: The generated text.
Comment on lines +55 to +62
Contributor

Is this supposed to be one completion or multiple completions? docstring says one completion, return type says list of completions.
I think list of completions makes more sense.

Contributor Author

It is a list of completions; the docstring needs fixing.

        """
        pass

    @abstractmethod
    def update_weights(
        self, weights_handle: dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]
Contributor

Our update_weights implementation is very similar, except the weights handle is essentially a handle given and tracked by torchstore

although now that I look at it, we should probably be passing the version explicitly rather than tracking implicitly in the policy cc @joecummings @pbontrager @pradeepfn

Contributor Author

Right. I have a TODO in the docstring; this needs to be a handle.

This also needs to be abstracted: we need a weights buffer abstraction, and the implementation can figure out whether it's a handle, raw bytes, etc.

Will update.
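
One possible shape for such an abstraction, sketched under assumptions: `WeightsBuffer`, `InMemoryWeightsBuffer`, and the explicit `version` argument are hypothetical names for illustration and do not come from the PR or from torchstore.

```python
from abc import ABC, abstractmethod


class WeightsBuffer(ABC):
    """Hypothetical abstraction; callers never see how weights are transported."""

    @abstractmethod
    def write(self, version: int, weights: dict[str, bytes]) -> None: ...

    @abstractmethod
    def read(self, version: int) -> dict[str, bytes]: ...


class InMemoryWeightsBuffer(WeightsBuffer):
    """Raw-bytes implementation; a remote-handle one could share this interface."""

    def __init__(self) -> None:
        self._store: dict[int, dict[str, bytes]] = {}

    def write(self, version: int, weights: dict[str, bytes]) -> None:
        self._store[version] = dict(weights)

    def read(self, version: int) -> dict[str, bytes]:
        if version not in self._store:
            raise KeyError(f"no weights published for version {version}")
        return dict(self._store[version])


buffer = InMemoryWeightsBuffer()
buffer.write(1, {"layer.weight": b"\x00\x01"})
latest = buffer.read(1)
```

Passing the version explicitly, as suggested in the comment above, lets a reader ask for exactly the weights it expects instead of tracking the version implicitly.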

    ):
        """
        Update the weights of the model.
        Args:
            weights: A dictionary of weights to update.
        """
        # TODO: NEEDS fixing: the weights_handle should be remote handle, like RDMA Buffer handle
        pass


class Scorer(ABC):
    @abstractmethod
    def score(self, completion: Completion) -> ScoredCompletion:
        pass

    def score_batch(
        self, completions: Sequence[Completion]
    ) -> Sequence[ScoredCompletion]:
        """
        Optionally override for efficient batch scoring.
        """
        return [self.score(c) for c in completions]


class PromptDataset(IterableDataset):
    """
    Users should inherit from this and implement __iter__.
    """

    def __iter__(self) -> Iterator[Prompt]:
        """
        defines how to generate or yield SimpleElement objects.
        """
        raise NotImplementedError


class PromptDataLoader(DataLoader):
    """
    subclass of DataLoader that handles batching, parallelism, and other data serving concerns.
    """

    def __init__(self, dataset: PromptDataset, **kwargs):
        super().__init__(dataset, **kwargs)
31 changes: 31 additions & 0 deletions src/forge/data_models/completion.py
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

import torch
from forge.data_models.prompt import Prompt


@dataclass
class Completion:
Contributor

similar to Experience, I think we should incorporate atomic data representations into Forge with implemented converters like

@dataclass
class Completion:
    @classmethod
    def from_vllm_outputs(cls, request_output: RequestOutput) -> "Completion":
        ...

cc @pbontrager @Jack-Khuu @joecummings
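
A sketch of the converter idea, with loud assumptions: `FakeRequestOutput` is a made-up container and is NOT vLLM's `RequestOutput`; `SimpleCompletion` and `from_request_output` are likewise hypothetical and only show the classmethod-converter pattern.

```python
from dataclasses import dataclass


@dataclass
class FakeRequestOutput:
    """Hypothetical engine output; this is NOT vLLM's RequestOutput class."""

    prompt: str
    prompt_token_ids: list[int]
    output_texts: list[str]
    output_token_ids: list[list[int]]


@dataclass
class SimpleCompletion:
    """Hypothetical, list-based stand-in for forge's Completion."""

    prompt: str
    text: str
    prompt_ids: list[int]
    token_ids: list[int]

    @classmethod
    def from_request_output(cls, out: FakeRequestOutput) -> list["SimpleCompletion"]:
        """Build one completion per generated sample."""
        return [
            cls(out.prompt, text, out.prompt_token_ids, ids)
            for text, ids in zip(out.output_texts, out.output_token_ids)
        ]


raw = FakeRequestOutput("hi", [1, 2], ["there", "you"], [[3, 4], [5]])
converted = SimpleCompletion.from_request_output(raw)
```

Keeping the conversion on the data class means engine-specific types stay at the boundary and the rest of the pipeline sees a single representation.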

    """A model-generated completion for a given prompt."""

    # The original prompt.
    prompt: Prompt

    # the decoded text returned by the model
    text: str

    # the encoded text (token ids) that were fed into the model
    prompt_ids: torch.Tensor

    # the encoded text (token ids) that were generated by the model
    token_ids: torch.Tensor

    # the log probabilities of the target tokens
    log_probs: Optional[torch.Tensor] = None
64 changes: 64 additions & 0 deletions src/forge/data_models/distributed_metric.py
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass
class DistributedMetric(ABC):
    """Metrics that are calculated in distributed fashion.

    Metrics computed in each rank are going to be wrapped in DistributedMetric
Contributor

I get the point but I am confused about the api.
Can you give a more detailed example here?

Contributor Author

I have something in my private branch which provides a motivation for how this will be used. https://github.com/meta-pytorch/forge/blob/rithesh/reinforce/src/forge/trainers/huggingface_trainer.py#L61-L119

Happy to chat further.

    according to how they are going to be aggregated. For example, average log prob
    can be wrapped as `Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))` where
Contributor

Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))

Is this how it's represented? I get the intention but I kind of hate the math DSL it implies lol

cc @felipemello1

I wonder if there's any way we can use DTensor for this more elegantly, cc @LucasLLC

Contributor

@felipemello1 felipemello1 Sep 15, 2025

I already did some work on metric logging aggregation. It's well tested. I think we should check if we can/should use it here: https://github.com/meta-pytorch/forge/tree/main/src/forge/data/dataset_metrics#2-metricsaggregator

When I did it, it was heavily focused on datasets, but it doesn't have to be. It should work for observability, rewards, etc.

Contributor Author

I have something in my private branch which provides a motivation for how this will be used. https://github.com/meta-pytorch/forge/blob/rithesh/reinforce/src/forge/trainers/huggingface_trainer.py#L61-L119

This is specifically needed if you are using the computationally efficient accumulate-and-apply-gradients approach rather than a single train step for every batch. Happy to chat further.
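
To illustrate why the numerator and denominator are reduced separately, here is a small simulation in plain Python. The two "ranks" and their values are made up, and summing over the list stands in for what an all-reduce with SUM would do; no real process group is involved.

```python
# Each simulated rank holds per-token log-probs and a validity mask.
ranks = [
    {"logp": [-1.0, -1.0, 0.0, 0.0], "mask": [1, 1, 0, 0]},    # 2 valid tokens
    {"logp": [-2.0, -2.0, -2.0, -2.0], "mask": [1, 1, 1, 1]},  # 4 valid tokens
]


def local_sums(rank):
    """Local numerator (masked log-prob sum) and denominator (valid-token count)."""
    num = sum(lp * m for lp, m in zip(rank["logp"], rank["mask"]))
    den = sum(rank["mask"])
    return num, den


# Fraction(Sum(...), Sum(...)): sum numerators and denominators across ranks,
# then divide once. This is the true per-token average over all ranks.
total_num = sum(local_sums(r)[0] for r in ranks)
total_den = sum(local_sums(r)[1] for r in ranks)
global_mean = total_num / total_den

# Naive alternative: average the per-rank means. Rank 0 has fewer valid tokens,
# so its tokens end up over-weighted.
mean_of_means = sum(n / d for n, d in map(local_sums, ranks)) / len(ranks)
```

With these numbers the two results differ (about -1.667 versus -1.5). Carrying the sum and the count as separate reducible quantities avoids that mismatch, and the same applies when accumulating across minibatches before applying gradients.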

    `mask` indicates which token is valid.
    """

    # We need to pass a context argument for distribution setup in the future.
    @abstractmethod
    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        pass

    @abstractmethod
    def local(self) -> torch.Tensor:
        pass


@dataclass
class SumDistributedMetric(DistributedMetric):
    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor

    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        return _try_clone_and_reduce(self.tensor, op=dist.ReduceOp.SUM, group=group)

    def local(self) -> torch.Tensor:
        return self.tensor


@dataclass
class Fraction:
    numerator: DistributedMetric
    denominator: DistributedMetric

    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        return self.numerator.reduce(group) / self.denominator.reduce(group)

    def local(self) -> torch.Tensor:
        return self.numerator.local() / self.denominator.local()


def _try_clone_and_reduce(
    tensor: torch.Tensor, op: dist.ReduceOp, group: dist.ProcessGroup | None
) -> torch.Tensor:
    cloned = tensor.detach().clone()
    if dist.is_initialized():
        dist.all_reduce(cloned, op=op, group=group)
    return cloned
68 changes: 68 additions & 0 deletions src/forge/data_models/experience.py
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch
from forge.data_models.scored_completion import ScoredCompletion


@dataclass
class Experience:
Contributor

Let's say we have a multi-turn conversation. Does that correspond to one Experience, or multiple Experiences?

Contributor

For example, are we supposed to have the following mask?
[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1]

Contributor Author

The Experience class is generic and is agnostic to single-turn or multi-turn structure. It is up to your data processing pipeline to decide how to chunk conversations into Experience instances.

If you want to train on multi-turn conversations, you would typically:

  • Concatenate all turns into a single sequence of token ids.
  • Set the mask and weights appropriately
  • Store this as a single Experience.

OR

If you want to treat each turn as a separate training example, you could split the conversation into multiple Experience instances, one per turn.
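
For the single-Experience option, the mask construction can be sketched with plain lists. The turn lengths are chosen to reproduce the mask asked about earlier in this thread; the token ids and the scalar `advantage` are made-up placeholders, and the real `Experience` holds tensors rather than lists.

```python
# Each turn is (role, token_ids); the ids are made-up placeholders.
turns = [
    ("user", [11, 12, 13, 14]),
    ("assistant", [21, 22, 23, 24]),
    ("user", [31, 32, 33, 34]),
    ("assistant", [41, 42, 43]),
    ("user", [51, 52]),
    ("assistant", [61, 62]),
]

ids: list[int] = []
mask: list[int] = []
for role, token_ids in turns:
    ids.extend(token_ids)
    # Train only on tokens the model generated.
    mask.extend([1 if role == "assistant" else 0] * len(token_ids))

advantage = 0.5  # hypothetical scalar score for the whole conversation
weights = [m * advantage for m in mask]
```

This yields the alternating mask from the question, with nonzero weights only on assistant tokens.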

Contributor

overall I like this representation. It seems like the right primitive to pass to the trainer, avoiding the trainer having to care about implementation details on getting to the right format.

Also like the idea of from(...) but my stylistic preference is like

@dataclass
class Experience:

    @classmethod
    def from_scored_completions(cls, ...) -> Experience:
        ...

cc @joecummings

    """
    The Experience data class to be used by the trainer.

    Experiences are usually generated from a scored completion and running various post processing steps.
    """

    # Concatenated prompt and sample token ids.
    ids: torch.Tensor
Contributor

we currently do this concat in the train_step which I don't love


    # The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
    mask: torch.Tensor

    # The weight to apply to the loss of each target token. It's normally computed
    # from the advantage and the reward.
    weights: torch.Tensor

    # The log probabilities of the target tokens, for prompt part it's set to 0,
    # for generation part it's computed from the Generator/Sampler.
    log_probs: Optional[torch.Tensor] = None

    # TODO: add more fields as required
    state: str = ""


def from_scored_completion(scored_completion: ScoredCompletion) -> Experience:
    """Converts a ScoredCompletion to an Experience."""
    prompt_ids = scored_completion.completion.prompt_ids
    token_ids = scored_completion.completion.token_ids
    log_probs = scored_completion.completion.log_probs
    ids = torch.cat([prompt_ids, token_ids])
    mask = torch.cat(
        [
            torch.zeros(prompt_ids.shape, dtype=torch.float32),
            torch.ones_like(token_ids, dtype=torch.float32),
        ]
    )
    advantage = scored_completion.score
    weights = mask * advantage
    log_probs = torch.cat(
        [
            torch.zeros(prompt_ids.shape, dtype=torch.float32),
            # TODO: this only works if sample.log_probs is 1
            log_probs,
        ]
    )
    return Experience(ids=ids, mask=mask, weights=weights, log_probs=log_probs)


def from_scored_completions(
    scored_completions: Sequence[ScoredCompletion],
) -> Sequence[Experience]:
    """Converts a sequence of ScoredCompletion to a sequence of Experiences."""
    return [from_scored_completion(sc) for sc in scored_completions]
22 changes: 22 additions & 0 deletions src/forge/data_models/loss.py
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from forge.data_models.distributed_metric import Fraction
from forge.data_models.minibatch import Minibatch


@dataclass
class LossInput:
    minibatch: Minibatch
    trainer_logits: torch.Tensor


@dataclass
class LossOutput:
    loss: Fraction
84 changes: 84 additions & 0 deletions src/forge/data_models/minibatch.py
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Sequence

import torch
from forge.data_models.experience import Experience


@dataclass
class Minibatch:
    """The minibatch that trainer will receive."""

    # The input sequence token ids for the trainer forward pass.
    input_ids: torch.Tensor

    # The segment ids for the input sequence token ids. Same segment
    # ids represent the same sequence.
    segment_ids: torch.Tensor

    # The targets required for loss computation, usually concatenated prompt and
    # sample token ids.
    target_ids: torch.Tensor

    # The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
    target_mask: torch.Tensor

    # The weight to apply to the loss of each target token. It's normally computed
    # from the advantage and the reward.
    target_weights: torch.Tensor

    # The log probabilities of the target tokens, for prompt part it's set to 0,
    # for generation part it's computed from the sampler.
    target_log_probs: torch.Tensor


def from_experiences(
    exps: Sequence[Experience], max_seq_len: int, pad_val: int = 0
) -> Minibatch:
    """
    Convert a list of experiences to a minibatch.
    """

    def pack_sequence(
Contributor

@felipemello1 felipemello1 Sep 11, 2025

Contributor Author

Thanks Felipe, I will take a look to see if I can use the packed dataset instead.

        tensors: Sequence[torch.Tensor],
        pad_val: Any,
        dtype: torch.dtype,
        max_len: int,
    ) -> torch.Tensor:
        """Packs multiple tensors along the seq dim."""
        seq = torch.cat(tensors)
        pad_len = max_len - seq.size(0)
        if pad_len < 0:
            raise ValueError(
                f"Sequence length {seq.size(0)} exceeds the maximum length {max_len}"
            )
        return torch.nn.functional.pad(seq, (0, pad_len), value=pad_val)[None, ...].to(
            dtype
        )

    mini_batch = {}
    exp_list = defaultdict(list)
    for i, exp in enumerate(exps):
        input_ids = exp.ids[:-1]
        exp_list["input_ids"].append(input_ids)
        exp_list["target_ids"].append(exp.ids[1:])
        exp_list["segment_ids"].append(torch.ones_like(input_ids) * i)
        exp_list["target_mask"].append(exp.mask[1:])
        exp_list["target_weights"].append(exp.weights[1:])
        exp_list["target_log_probs"].append(exp.log_probs[1:])

    for k, v in exp_list.items():
        _dtype = torch.int64
        if k == "target_mask" or k == "target_weights" or k == "target_log_probs":
            _dtype = torch.float32

        mini_batch[k] = pack_sequence(v, pad_val, _dtype, max_seq_len)

    return Minibatch(**mini_batch)
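
The shift-and-pack logic in `from_experiences` can be traced with plain lists. This is a simplified sketch, not the function above: `pack` is a hypothetical helper that handles only the id and segment fields and returns lists instead of tensors.

```python
def pack(experiences: list[list[int]], max_seq_len: int, pad_val: int = 0):
    """Shift each sequence by one token and pack all of them into one row."""
    input_ids, target_ids, segment_ids = [], [], []
    for i, ids in enumerate(experiences):
        input_ids += ids[:-1]                 # the model sees tokens 0..n-2
        target_ids += ids[1:]                 # and predicts tokens 1..n-1
        segment_ids += [i] * (len(ids) - 1)   # marks which experience each token is from
    pad_len = max_seq_len - len(input_ids)
    if pad_len < 0:
        raise ValueError(f"Sequence length {len(input_ids)} exceeds {max_seq_len}")
    pad = [pad_val] * pad_len
    return input_ids + pad, target_ids + pad, segment_ids + pad


inputs, targets, segments = pack([[1, 2, 3], [4, 5, 6, 7]], max_seq_len=8)
```

Each experience of length n contributes n - 1 positions, so the two sequences above occupy 5 of the 8 slots and the remainder is padding.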