
Conversation

@joecummings joecummings commented Aug 18, 2025

What this PR DOES do:

Core GRPO Implementation:

  • Complete GRPO training system for RL fine-tuning
  • Integration with vLLM policy actors for text generation during rollouts
  • Multi-reward system supporting math correctness and thinking tag rewards (see the sketch after this list)
  • Replay buffer with configurable batch sizes and policy version tracking
  • Async training loops with concurrent rollout generation and policy training
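
For illustration, a minimal sketch of what a multi-reward setup along these lines could look like (the function names and the `####` answer parsing are assumptions for this example, not the PR's actual code):

```python
import re


def thinking_tag_reward(response: str) -> float:
    # Reward responses that wrap their reasoning in <think>...</think> tags.
    return 1.0 if re.search(r"<think>.*?</think>", response, re.DOTALL) else 0.0


def math_correctness_reward(response: str, target: str) -> float:
    # Naive exact-match check on the final answer after a GSM8K-style "####" marker.
    answer = response.split("####")[-1].strip()
    return 1.0 if answer == target.strip() else 0.0


def total_reward(response: str, target: str) -> float:
    return math_correctness_reward(response, target) + thinking_tag_reward(response)
```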

Service Infrastructure:

  • Service-based architecture using Monarch actors for distributed components
  • Policy actors with configurable sampling, device assignment, and tensor parallelism
  • Trainer actors with GRPO loss computation, KL regularization, and gradient clipping
  • Reference model actors for computing baseline log probabilities
  • Advantage computation using reward-to-go with normalization
  • Dataset integration with GSM8K math problems

What this PR will NOT cover:

Missing Features:

  • Weight synchronization between trainer and policy (commented out: # await trainer.update_weights(policy)) @pradeepfn @pbontrager
  • Policy versioning system (hardcoded to version 0) @joecummings @Jack-Khuu
  • Logging/monitoring infrastructure (No wandb / tensorboard integration) @calvinpelletier @DNXie
  • Multi-turn conversations or complex dialog handling
  • Production-ready error handling and recovery mechanisms @allenwang28

Incomplete Work:

  • TODO: Move policy processing initialization into setup (line 453 in **main.py**) @Jack-Khuu
  • Manual device assignment instead of automatic GPU allocation @pbontrager
  • Hardcoded hyperparameters throughout the system

Technical Limitations:

  • Single model support (Qwen3-1.7B)
  • Fixed reward functions (no pluggable reward interface) @DNXie
  • No distributed training across multiple machines

This is a working prototype of GRPO with the core RL loop functional but requiring additional work for production deployment, proper weight updates, and comprehensive logging.

@meta-cla meta-cla bot added the CLA Signed label Aug 18, 2025
training_step = 0
while True:
    batch = await replay_buffer.sample.choose(curr_policy_version=0)
    if batch is None:
Contributor:

I feel this should just be inside of the buffer.sample method. Then you'll just await until there's enough data for a sample.

Member Author:

Hmm, I disagree. The contract here is just that the buffer will return a sample when it computationally can, not only when there's something inside.

In addition, this would push possible errors a layer down if the replay buffer isn't getting filled for some reason. I'd rather have this be logic that's exposed to the user.

Contributor:

When does sample return None? Is it when the number of usable rollouts is < batch_size? I feel like this is going to surprise some people who try to use this and start running into Nones in their batch.

Member Author:

Yeah, if the replay buffer has nothing in it.

Contributor:

I think exposing the logic to the user makes sense; my only nit is that I might want it to have more of a queue-like semantic, e.g. we set a timeout and it raises an error if we hit the timeout.

Member Author:

Does choose have a built-in timeout feature?

Contributor:

It doesn't, but the buffer itself can return an exception, which will be propagated through choose.

However, the current service implementation will then mark the replica as unhealthy, so we shouldn't add this yet.
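
A possible caller-side version of the queue-like semantics being discussed, as a hedged sketch (the `replay_buffer.sample.choose` call is from the PR; the polling helper, its name, and the timeout behavior are assumptions):

```python
import asyncio


async def sample_with_timeout(replay_buffer, policy_version: int,
                              timeout_s: float = 60.0, poll_s: float = 1.0):
    """Poll the replay buffer until a batch is available or the deadline passes."""
    deadline = asyncio.get_running_loop().time() + timeout_s
    while True:
        batch = await replay_buffer.sample.choose(curr_policy_version=policy_version)
        if batch is not None:
            return batch
        if asyncio.get_running_loop().time() >= deadline:
            raise TimeoutError("replay buffer produced no batch before the timeout")
        await asyncio.sleep(poll_s)
```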

Contributor
@ebsmothers ebsmothers left a comment

A bunch of initial questions -- I understand this is not really meant to be the final form of the GRPO loop so let me know which gaps are deliberate hacks to unblock vs ones that require more thought. (Ideally just put this in the PR summary to be explicit.)

Would also be good to start factoring out some stuff like rewards, actors, etc into separate files so that main.py starts looking closer to a pure training script.

) # Remove batch dimension for single response


class DatasetActor(ForgeActor):
Contributor:

Longer-term what is our plan here? (I.e. is our HfIterableDataset compatible with how we're setting things up here?)

Member Author:

I think we need to rebuild our HfIterableDataset based on the patterns we observe when testing out this training loop. Right now, it's not super intuitive to use, and it's also not clear what we actually need from it.

My guess is that we'll keep 75% of the current implementation, but I don't want to assume before we actually start testing.

Comment on lines +375 to +376
ds = ds.map(gsm8k_to_messages)
ds = ds.shuffle()
Contributor:

I notice no split_dataset_by_node. Is this by design?

Member Author:

No, just missing it. Will add.
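
For reference, a sketch of how `split_dataset_by_node` from Hugging Face `datasets` would slot in here (the rank/world-size plumbing is an assumption about how this actor gets its coordinates):

```python
from datasets.distributed import split_dataset_by_node

# Shard the dataset so each data-loading rank sees a disjoint slice,
# then apply the same transforms as before.
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
ds = ds.map(gsm8k_to_messages)
ds = ds.shuffle()
```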

avg_reward = sum(group.reward for group in episode.groups) / len(
    episode.groups
)
wandb.log({"rollout_count": rollout_count, "avg_reward": avg_reward})
Contributor:

Will this eventually be its own actor?

Member Author:

I'm not sure it needs to be, unless we're making so many requests from all over to log things that we want them going through a central source.

Otherwise, we can just treat the logger as a normal component.

Contributor:

Would wandb complain about having multiple "sessions" created from the same job? If we can minimize the messages being passed through Monarch, that generally seems like a good idea.

Contributor:

I think we want it to be its own actor so logs can be called from within actors too. It's a lot of unnecessary data passing otherwise, especially if people log artifacts.

Contributor:

Hmm, I'm not sure I understand that comment @pbontrager.

If the metrics logger is an actor, then all actors would need to pass messages over Monarch to the metrics actor, right? If actors can just write directly to wandb, that would minimize data passing.
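
For concreteness, a rough sketch of the actor-based option being debated, reusing the `ForgeActor`/`@endpoint` pattern visible elsewhere in this PR (the class name, endpoints, and wandb usage here are assumptions, not code from the PR):

```python
# `ForgeActor` and `endpoint` are the project's actor primitives (imports omitted).
class MetricsLogger(ForgeActor):
    """Single owner of the wandb session; other actors send it small metric dicts."""

    @endpoint
    async def setup(self, project: str) -> None:
        import wandb

        self.wandb = wandb
        self.wandb.init(project=project)

    @endpoint
    async def log(self, metrics: dict, step: int | None = None) -> None:
        # Called remotely by trainer / rollout actors instead of each opening
        # its own wandb session.
        self.wandb.log(metrics, step=step)
```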

self.lambda_ = lambda_ # GAE lambda parameter

@endpoint
async def __call__(self, groups: list[Group]) -> list[float]:
Contributor:

Is this based on some specific reference implementation?

Member Author:

My brain, the old impl of GRPO in torchtune, and good ol devmate.

Contributor:

I'm not sure I trust that consensus...

Contributor:

We're probably not going to want a Group type, but we've put the type work on the back burner for now.

Contributor:

We should spin up a track fully dedicated to ensuring numerical correctness (which I think is the basis of @ebsmothers' comment), but IMO the right timing for this is after we have the working prototype.
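
For that correctness track, the usual GRPO baseline is a group-relative advantage, i.e. each response's reward normalized against its own group's statistics; a reference sketch (not necessarily what this PR computes) looks like:

```python
import torch


def grpo_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    # A_i = (r_i - mean(r)) / (std(r) + eps), computed per group of responses
    # sampled from the same prompt.
    r = torch.tensor(rewards, dtype=torch.float32)
    return ((r - r.mean()) / (r.std() + eps)).tolist()
```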

@endpoint
async def add(self, trajectory: Trajectory) -> None:
    self.buffer.append(trajectory)
async def add(self, episode) -> None:
Contributor:

So we are moving from Trajectories to Episodes in the replay buffer? Isn't this overly tailored to GRPO?

Member Author:

Everything in here is overly tailored to GRPO. I want to base our design decisions on real life, otherwise we can talk all day. This is definitely not the final state of our APIs.

Contributor:

Episode isn't GRPO-specific; it's basically another word for Trajectory.

@pbontrager pbontrager mentioned this pull request Aug 26, 2025
@joecummings joecummings changed the title [WIP] Skeleton of GRPO Skeleton of GRPO Aug 27, 2025
Contributor
@allenwang28 allenwang28 left a comment

This is awesome, great work @joecummings!

self.prompt = prompt
self.target = target
self.policy_version = policy_version
self.groups: list[Group] = []
Contributor:

Naming nit, but isn't what we're calling Group here actually an individual output, and what we're calling self.groups the actual Group?

Member Author:

Fair lol

ref_logprobs_list = []
advantages_list = []

for group in groups:
Contributor:

No action needed now, but IMO it would be nifty for the episode structure itself to handle things like this:

for episode in batch:
    tokenized = self.tokenizer(
        episode.as_response_group(),
        ...
    )
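
A hedged sketch of what that could look like; the field names and the `as_response_group` helper are hypothetical, loosely based on the `prompt`/`groups`/`reward` attributes visible in this diff:

```python
from dataclasses import dataclass, field


@dataclass
class Group:
    response: str
    reward: float


@dataclass
class Episode:
    prompt: str
    groups: list[Group] = field(default_factory=list)

    def as_response_group(self) -> list[str]:
        # One prompt+response string per group member, so the trainer can
        # tokenize a whole episode with a single tokenizer call.
        return [self.prompt + g.response for g in self.groups]
```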


Contributor
@pbontrager pbontrager left a comment

This is really great! There will be a lot of refactor work going forward so I won't push for changes here.

@endpoint
async def setup(self):
    # Set up policy_worker
    self.available_devices = (
Contributor:

This should really come from the script and not the config. Ideally the script makes the master set of ranks and then passes them to the actors.
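
A hedged sketch of that suggestion, i.e. the driver script owning the device map and handing slices to the actors at spawn time (the partitioning and the spawn call in the comment below are assumptions, not the project's API):

```python
import torch

# The script builds one master list of devices...
all_devices = list(range(torch.cuda.device_count()))

# ...and partitions it explicitly instead of each actor reading a config.
policy_devices = all_devices[:2]
trainer_devices = all_devices[2:4]
ref_model_devices = all_devices[4:6]

# Hypothetical: passed into each actor when it is spawned, e.g.
# policy = await spawn(Policy, available_devices=policy_devices)
```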

@joecummings joecummings merged commit 6d76a41 into meta-pytorch:main Aug 28, 2025
4 checks passed