
Conversation

@Ritesh1905 (Contributor) commented Sep 11, 2025

Summary:

  • Defining core abstractions.
  • Hoping to use this PR to get feedback on the abstractions
  • Not picky about the terminology (Prompt vs. Experience vs. Completion vs. Minibatch, etc.), so please suggest alternatives.

Test Plan:
Unit tests in the next PR once we have the stubs implemented

@meta-cla bot added the CLA Signed label Sep 11, 2025
@Ritesh1905 marked this pull request as ready for review September 11, 2025 18:44
Convert a list of experiences to a minibatch.
"""

def pack_sequence(
@felipemello1 (Contributor) commented Sep 11, 2025

Contributor Author (@Ritesh1905):

Thanks Felipe, I will take a look to see whether I can use the packed dataset instead.

Comment on lines +44 to +50
"""
applying accumulated gradients to the model parameters.

the return type is a tuple of weights buffer, dtype, and shape of the original tensor.
"""
# TODO: NEEDS fixing: the weights_handle should be remote handle, like RDMA Buffer handle
pass
Contributor:

Why do we also apply gradients in snapshot_weights here? Or is it just a wrong docstring?

Contributor Author:

Good catch. Wrong docstring; needs fixing.

Comment on lines +55 to +62
def generate(self, prompt: Prompt, **kwargs) -> list[Completion]:
"""
Generate a completion given a prompt.
Args:
prompt: The input prompt.
**kwargs: Additional model-specific generation parameters.
Returns:
str: The generated text.
Contributor:

Is this supposed to be one completion or multiple completions? docstring says one completion, return type says list of completions.
I think list of completions makes more sense.

Contributor Author:

It is a list of completions; the docstring needs fixing.

class DistributedMetric(ABC):
"""Metrics that are calculated in distributed fashion.

Metrics computed in each rank are going to be wrapped in DistributedMetric
Contributor:

I get the point, but I am confused about the API.
Can you give a more detailed example here?

Contributor Author:

I have something in my private branch which provides a motivation for how this will be used. https://github.com/meta-pytorch/forge/blob/rithesh/reinforce/src/forge/trainers/huggingface_trainer.py#L61-L119

Happy to chat further.



@dataclass
class Experience:
Contributor:

Let's say we have a multi-turn conversation. Does that correspond to one Experience, or multiple Experiences?

Contributor:

For example, are we supposed to have the following mask?
[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1]

Contributor Author:

The Experience class is generic and is agnostic to single-turn or multi-turn structure. It is up to your data processing pipeline to decide how to chunk conversations into Experience instances.

If you want to train on multi-turn conversations, you would typically:

  • Concatenate all turns into a single sequence of token ids.
  • Set the mask and weights appropriately.
  • Store this as a single Experience (a sketch of this option follows below).

OR

If you want to treat each turn as a separate training example, you could split the conversation into multiple Experience instances, one per turn.
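
As a rough illustration of the first option, here is a minimal sketch that packs a multi-turn conversation into one Experience. It assumes Experience carries ids, mask, and weights tensors; only ids appears in the excerpts in this PR, so the other fields and the helper itself are illustrative:

import torch


def conversation_to_experience(turns: list[dict]) -> "Experience":
    """Hypothetical helper: pack all turns of a conversation into one Experience."""
    ids: list[int] = []
    mask: list[int] = []
    for turn in turns:
        ids.extend(turn["token_ids"])
        # Train only on assistant tokens: mask is 1 for assistant tokens, 0 otherwise.
        keep = 1 if turn["role"] == "assistant" else 0
        mask.extend([keep] * len(turn["token_ids"]))
    return Experience(
        ids=torch.tensor(ids, dtype=torch.long),
        mask=torch.tensor(mask, dtype=torch.bool),
        # Uniform per-token weights here; a real pipeline might use advantages instead.
        weights=torch.tensor(mask, dtype=torch.float),
    )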



@dataclass
class Prompt:
Contributor:

I don't know about the name Prompt.

Contributor Author:

I am not picky about the names. The RL world has tons of jargon. Happy to pick whichever the team agrees on (this applies to everything in this diff).

from forge.data_models.prompt import Prompt


class VLLMGenerator(Generator):
Contributor:

This and the Generator in api.py are supposed to be Monarch Actors, no?

Contributor Author:

Nope. The intention is to distinguish between the implementation and the actor: the actor would be a very skinny wrapper into which a specific generator/trainer is injected. Keeping the actor as skinny as possible allows us to inject any generator implementation.

class Policy(Actor):
    def __init__(self, generator: Generator):
        super().__init__()
        rank = current_rank()
        self.rank = rank.rank
        self.local_rank = rank["gpus"]
        self.world_size = rank.extent.nelements
        self._set_env_vars()
        self.generator = generator

    @endpoint
    async def generate(self, prompt: Prompt) -> list[Completion]:
        return self.generator.generate(prompt)

    @endpoint
    async def update_weights(
        self, weights_buffer: dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]
    ):
        return self.generator.update_weights(weights_buffer)

    def _set_env_vars(self):
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["RANK"] = str(self.rank)
        os.environ["LOCAL_RANK"] = str(self.local_rank)

@allenwang28 (Contributor) left a comment:

Thanks @Ritesh1905 for putting this together and for your patience! I took a look through and pointed out what was different with what we did, tagging people I think who are most relevant.

My overall 2c:

  • I like the data abstractions - as it is, Forge is sort of the glue between Titan and vLLM. What I don't love is that some of the abstractions leak (e.g. we have to convert a vLLM output into an Experience, and even in the train loop we have to do more conversions, etc.).
  • Our actors are almost the same, except for the trainer. I would like to better understand the differences - overall I think train_step is probably not enough functionality to expose to the end user, but I don't have enough context to understand how it's organized in TBR?


@abstractmethod
def update_weights(
self, weights_handle: dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]
Contributor:

Our update_weights implementation is very similar, except the weights handle is essentially a handle given and tracked by torchstore

although now that I look at it, we should probably be passing the version explicitly rather than tracking implicitly in the policy cc @joecummings @pbontrager @pradeepfn

Contributor Author:

Right. I have a TODO in the docstring; this needs to be a handle.

Also this needs to be abstracted too. We need a weights buffer abstraction, and the implementation can figure out whether it's a handle, raw bytes, etc.

Will update.
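
As a rough sketch of what that abstraction could look like (all names below are hypothetical and not part of this PR):

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch


class WeightsBuffer(ABC):
    """Abstract handle to a snapshot of model weights."""

    @abstractmethod
    def materialize(self) -> dict[str, torch.Tensor]:
        """Resolve the buffer into named tensors on the local device."""


@dataclass
class InMemoryWeightsBuffer(WeightsBuffer):
    """Simplest backend: raw tensors held in process memory."""

    tensors: dict[str, torch.Tensor]

    def materialize(self) -> dict[str, torch.Tensor]:
        return self.tensors


@dataclass
class RemoteHandleWeightsBuffer(WeightsBuffer):
    """Placeholder for an RDMA/torchstore-style remote handle backend."""

    handle: object

    def materialize(self) -> dict[str, torch.Tensor]:
        raise NotImplementedError("fetch tensors via the remote handle")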


class Trainer(ABC):
@abstractmethod
def accummulate_gradients(self, minibatch: Minibatch) -> LossOutput:
Contributor:

why are accumulate/apply kept separate?

Aside: we should be able to implement this in our Titan trainer by pulling loss.backward() out into its own function

cc @wwwjn @felipemello1

@Ritesh1905 (Contributor Author) commented Sep 15, 2025:

Mainly computational efficiency.
This method helps reduce the frequency of expensive parameter updates and potentially enables better utilization of hardware (we can batch operations like all-reduce in distributed scenarios). Essentially, you don't need to update the params for every batch/minibatch; you can just accumulate gradients and then apply them once a suitable number of batches has been processed.

Something like below...

for step in range(20):
    for mini_batch in range(10):
        # accumulate gradients for each mini-batch
        trainer.accummulate_gradients(mini_batch)

    # apply the accumulated gradients once per step
    trainer.apply_gradients()

We can expose another step API in trainer... which would be a combination of accumulate and apply.
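
A minimal sketch of what that combined API could look like, written here as a free function for clarity (apply_gradients is assumed to be the counterpart of accummulate_gradients; neither it nor the wrapper below is part of this PR):

def train_step(trainer: Trainer, minibatches: list[Minibatch]) -> list[LossOutput]:
    """Combined step: accumulate over all minibatches, then apply once."""
    outputs = [trainer.accummulate_gradients(mb) for mb in minibatches]
    trainer.apply_gradients()  # assumed counterpart of accummulate_gradients
    return outputs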

Contributor:

got it, that makes sense to me

Contributor:

OK, our Titan integration doesn't yet have grad accumulation (#146) - it seems possible to add this API surface along with the high-level step.

pass

@abstractmethod
def snapshot_weights(
Contributor:

Is the contract here that the snapshotted weights never change so long as this handle exists?

Contributor Author:

Yes, for a single RL loop run. As long as the learner and policy have the same remote handle reference, this can be achieved in 3 steps (see the sketch after this list):

  1. At init, the controller provides the same remote handle to the learner and the policy.
  2. When it's time to update weights, the controller requests the trainer to push weights to the remote buffer.
  3. The controller requests the policy to read the buffer and update its weights.
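
A sketch of that flow from the controller's perspective; helper names and exact signatures below are illustrative, not taken from this PR:

# 1. Init: the controller hands the same remote handle to the learner and the policy.
handle = create_remote_weights_handle()          # hypothetical allocation helper
trainer = make_trainer(weights_handle=handle)    # learner side
policy = make_policy(generator, weights_handle=handle)

# 2. Update time: the controller asks the trainer to push current weights into the buffer.
trainer.snapshot_weights()

# 3. The controller asks the policy to read that buffer and refresh its weights.
policy.update_weights(handle)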

Contributor:

OK, yeah, that makes sense. We're calling it push_weights now, but snapshot may be more accurate.


class Generator(ABC):
@abstractmethod
def generate(self, prompt: Prompt, **kwargs) -> list[Completion]:
Contributor:

our generate returns the RequestOutput in order to access the tokenized input IDs cc @pbontrager @Jack-Khuu

(but I see that the Completion is not the vLLM CompletionOutput)

Contributor Author:

(but I see that the Completion is not the vLLM CompletionOutput)

This is intentional; the idea is to keep the API very generic. For example, we could have sglang/HF-based generators, not just vLLM-based ones.
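
For example, a sketch of what a non-vLLM backend could look like behind the same interface. Everything below is illustrative: it assumes a Hugging Face causal LM, that Prompt exposes the raw text, and that Completion can be built from generated token ids.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class HFGenerator(Generator):
    """Illustrative Hugging Face-backed implementation of the generic Generator API."""

    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def generate(self, prompt: Prompt, **kwargs) -> list[Completion]:
        # `prompt.text` and `Completion(token_ids=...)` are assumed field names.
        inputs = self.tokenizer(prompt.text, return_tensors="pt")
        output_ids = self.model.generate(**inputs, **kwargs)
        # Drop the prompt tokens so only the sampled continuation remains.
        new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
        return [Completion(token_ids=row) for row in new_tokens]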

"""

# Concatenated prompt and sample token ids.
ids: torch.Tensor
Contributor:

we currently do this concat in the train_step, which I don't love



@dataclass
class Experience:
Contributor:

overall I like this representation. It seems like the right primitive to pass to the trainer, avoiding the trainer having to care about implementation details on getting to the right format.

I also like the idea of from(...), but my stylistic preference is something like:

@dataclass
class Experience:

    @classmethod
    def from_scored_completions(cls, ...) -> Experience:
        ...

cc @joecummings


Metrics computed in each rank are going to be wrapped in DistributedMetric
according to how they are going to be aggregated. For example, average log prob
can be wrapped as `Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))` where
Contributor:

Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))

Is this how it's represented? I get the intention but I kind of hate the math DSL it implies lol

cc @felipemello1

I wonder if there's any way we can use DTensor for this more elegantly, cc @LucasLLC

@felipemello1 (Contributor) commented Sep 15, 2025:

I already did some work on metric logging aggregation. It's well tested. I think we should check whether we can/should use it here: https://github.com/meta-pytorch/forge/tree/main/src/forge/data/dataset_metrics#2-metricsaggregator

When I did it, it was heavily focused on datasets, but it doesn't have to be. It should work for observability, rewards, etc.

Contributor Author:

I have something in my private branch which provides a motivation for how this will be used. https://github.com/meta-pytorch/forge/blob/rithesh/reinforce/src/forge/trainers/huggingface_trainer.py#L61-L119

This is specifically needed if you are using the computationally efficient accumulate-and-apply-gradients approach rather than a single train step for every batch. Happy to chat further.
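
For readers of this thread, a minimal sketch of the wrapper idea the docstring implies. The Sum/Fraction names follow the docstring; the reduce() method and the all-reduce-based aggregation are assumptions, not the actual Forge API:

import torch
import torch.distributed as dist


class Sum:
    """Per-rank partial sum; aggregated by summing across ranks."""

    def __init__(self, value: torch.Tensor):
        self.value = value

    def reduce(self) -> torch.Tensor:
        out = self.value.clone()
        if dist.is_initialized():
            dist.all_reduce(out, op=dist.ReduceOp.SUM)
        return out


class Fraction:
    """Numerator and denominator are reduced separately, then divided."""

    def __init__(self, numerator: Sum, denominator: Sum):
        self.numerator = numerator
        self.denominator = denominator

    def reduce(self) -> torch.Tensor:
        return self.numerator.reduce() / self.denominator.reduce()


# Average log-prob over masked tokens, aggregated correctly across ranks:
# avg_logp = Fraction(Sum((logp * mask).sum()), Sum(mask.sum())).reduce()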



@dataclass
class Completion:
Contributor:

similar to Experience, I think we should incorporate atomic data representations into Forge with implemented converters, like:

@dataclass
class Completion:
    @classmethod
    def from_vllm_outputs(cls, request_output: RequestOutput) -> "Completion":
        ...

cc @pbontrager @Jack-Khuu @joecummings

@Jack-Khuu (Contributor): A different version of this has landed in main.

@Jack-Khuu closed this Oct 9, 2025