Skip to content

Conversation

Jack-Khuu
Copy link
Contributor

@Jack-Khuu Jack-Khuu commented Oct 14, 2025

The updates boil down to 2 changes that don't alter behavior in grpo/main:

  • Group is downgraded from a dataclass to a typedef oflist[Episode], since it's never required
  • Episode now directly holds a Completion with redundant attributes in Episode being removed
    • See df0e5a9 for how the redundant fields are mapped between Episode and ScoredCompletion.

(There's also various typehint improvements sprinkled in)


Note: This PR does not address or utilize Episode from data_models, but convergence is imminent

python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

Wandb looks roughly the same
Before: torchforge/grpo-training/runs/wca6wke2
After torchforge/grpo-training/runs/ul34xjr9

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 14, 2025
@Jack-Khuu Jack-Khuu marked this pull request as draft October 14, 2025 07:10
@Jack-Khuu Jack-Khuu changed the title [WIP] Flatten GRPO main: Group and Episode Flatten GRPO main: Group and Episode Oct 14, 2025
@Jack-Khuu Jack-Khuu marked this pull request as ready for review October 14, 2025 08:40
Copy link
Contributor

@JenniferWang JenniferWang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • 1 for removing the Group abstraction.

# Calculate advantages and add to replay buffer
advantages = await compute_advantages.compute.call_one(group)
for episode, advantage in zip(group.episodes, advantages):
advantages = await compute_advantages.compute.call_one(episodes)
Copy link
Contributor

@JenniferWang JenniferWang Oct 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this diff but now since we're scrutinizing the main flow again, I think making compute_advantages its own Actor is very weird and probably the opposite to an "optimization"

  1. We do not expose capability to specify the hostmesh for a specific actor -- ideally, this should be collocated with the generator replica that produces this batch.
  2. ComputeAdvantage only needs the rewards; so very likely the entire episodes are serialized.

I wonder, if for now it should be just inlined in the sample call; or allocating a proc on the Policy mesh along side the PolicyWorker to handle the computation but chain the calls and return the result together in sample

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JenniferWang these are good points. I want to propose an idea (not for you to implement @Jack-Khuu just brainstorming if this makes sense)

with policy.session() as s:
    host: HostMesh =  await s.get_host_mesh() # returns the host mesh associated with this replica
    advantages = host.run_task(compute_advantages) # where compute_advantages is a function

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chained calls would be cool 👀

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks legit; +1 on chained calls

# Calculate advantages and add to replay buffer
advantages = await compute_advantages.compute.call_one(group)
for episode, advantage in zip(group.episodes, advantages):
advantages = await compute_advantages.compute.call_one(episodes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JenniferWang these are good points. I want to propose an idea (not for you to implement @Jack-Khuu just brainstorming if this makes sense)

with policy.session() as s:
    host: HostMesh =  await s.get_host_mesh() # returns the host mesh associated with this replica
    advantages = host.run_task(compute_advantages) # where compute_advantages is a function

Copy link
Member

@joecummings joecummings left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome stuff! Just a bunch of small comments.

# Calculate advantages and add to replay buffer
advantages = await compute_advantages.compute.call_one(group)
for episode, advantage in zip(group.episodes, advantages):
advantages = await compute_advantages.compute.call_one(episodes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chained calls would be cool 👀

@Jack-Khuu Jack-Khuu merged commit 4c14792 into main Oct 14, 2025
9 checks passed
@Jack-Khuu Jack-Khuu deleted the grpo-group branch October 14, 2025 21:42
allenwang28 pushed a commit to allenwang28/forge that referenced this pull request Oct 15, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants