-
Notifications
You must be signed in to change notification settings - Fork 16
Flatten GRPO main: Group and Episode #400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 1 for removing the
Group
abstraction.
# Calculate advantages and add to replay buffer | ||
advantages = await compute_advantages.compute.call_one(group) | ||
for episode, advantage in zip(group.episodes, advantages): | ||
advantages = await compute_advantages.compute.call_one(episodes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not related to this diff but now since we're scrutinizing the main flow again, I think making compute_advantages
its own Actor
is very weird and probably the opposite to an "optimization"
- We do not expose capability to specify the hostmesh for a specific actor -- ideally, this should be collocated with the generator replica that produces this batch.
- ComputeAdvantage only needs the rewards; so very likely the entire episodes are serialized.
I wonder, if for now it should be just inlined in the sample
call; or allocating a proc on the Policy mesh along side the PolicyWorker to handle the computation but chain the calls and return the result together in sample
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JenniferWang these are good points. I want to propose an idea (not for you to implement @Jack-Khuu just brainstorming if this makes sense)
with policy.session() as s:
host: HostMesh = await s.get_host_mesh() # returns the host mesh associated with this replica
advantages = host.run_task(compute_advantages) # where compute_advantages is a function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Chained calls would be cool 👀
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks legit; +1 on chained calls
# Calculate advantages and add to replay buffer | ||
advantages = await compute_advantages.compute.call_one(group) | ||
for episode, advantage in zip(group.episodes, advantages): | ||
advantages = await compute_advantages.compute.call_one(episodes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JenniferWang these are good points. I want to propose an idea (not for you to implement @Jack-Khuu just brainstorming if this makes sense)
with policy.session() as s:
host: HostMesh = await s.get_host_mesh() # returns the host mesh associated with this replica
advantages = host.run_task(compute_advantages) # where compute_advantages is a function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome stuff! Just a bunch of small comments.
# Calculate advantages and add to replay buffer | ||
advantages = await compute_advantages.compute.call_one(group) | ||
for episode, advantage in zip(group.episodes, advantages): | ||
advantages = await compute_advantages.compute.call_one(episodes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Chained calls would be cool 👀
The updates boil down to 2 changes that don't alter behavior in
grpo/main
:Group
is downgraded from a dataclass to a typedef oflist[Episode]
, since it's never requiredEpisode
now directly holds aCompletion
with redundant attributes inEpisode
being removedEpisode
andScoredCompletion
.(There's also various typehint improvements sprinkled in)
Note: This PR does not address or utilize
Episode
from data_models, but convergence is imminentWandb looks roughly the same
Before: torchforge/grpo-training/runs/wca6wke2
After torchforge/grpo-training/runs/ul34xjr9