Skeleton of GRPO #58
Conversation
training_step = 0
while True:
    batch = await replay_buffer.sample.choose(curr_policy_version=0)
    if batch is None:
I feel this should just be inside of the buffer.sample method. Then you'll just await until there's enough data for a sample.
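A minimal sketch of that alternative, for concreteness (ForgeActor/@endpoint as used elsewhere in this PR; the batch_size field and polling interval are illustrative, not part of this change):

import asyncio
import random

class ReplayBuffer(ForgeActor):
    def __init__(self, batch_size: int = 8):
        self.buffer = []
        self.batch_size = batch_size

    @endpoint
    async def sample(self, curr_policy_version: int):
        # Hypothetical behavior: block until enough episodes have accumulated
        # instead of returning None to the training loop.
        while len(self.buffer) < self.batch_size:
            await asyncio.sleep(0.1)
        return random.sample(self.buffer, self.batch_size)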
Hmm, I disagree. The contract here is just that the buffer will return a sample when it computationally can, not when there's something inside.
In addition, this would push possible errors a layer down if the replay buffer isn't getting filled for some reason. I'd rather have this be exposed logic to the user.
When does sample return None? Is it when the number of usable rollouts < batch_size? I feel like this is going to surprise people who try to use this and start running into Nones in their batch.
Yeah if the replay buffer has nothing in it
I think exposing logic to the user makes sense, my only nit is that I might want it to have more of a queue-like semantic? Like we set a timeout, it raises an error if we hit the timeout etc.
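A hedged sketch of that queue-like wrapper on the caller side (the function name, polling interval, and timeout value are hypothetical; it only relies on asyncio.wait_for and the sample.choose call already used in the loop above):

import asyncio

async def sample_with_timeout(replay_buffer, policy_version: int, timeout_s: float = 30.0):
    # Poll the buffer until it returns a batch; raise asyncio.TimeoutError
    # if no batch shows up within timeout_s seconds.
    async def _poll():
        while True:
            batch = await replay_buffer.sample.choose(curr_policy_version=policy_version)
            if batch is not None:
                return batch
            await asyncio.sleep(0.5)

    return await asyncio.wait_for(_poll(), timeout=timeout_s)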
Does choose have a built-in timeout feature?
It doesn't, but the buffer itself can raise an exception, which will be propagated through choose.
However, the current service implementation will then mark the replica as unhealthy, so we shouldn't add this yet.
A bunch of initial questions -- I understand this is not really meant to be the final form of the GRPO loop so let me know which gaps are deliberate hacks to unblock vs ones that require more thought. (Ideally just put this in the PR summary to be explicit.)
Would also be good to start factoring out some stuff like rewards, actors, etc into separate files so that main.py starts looking closer to a pure training script.
)  # Remove batch dimension for single response


class DatasetActor(ForgeActor):
Longer-term what is our plan here? (I.e. is our HfIterableDataset compatible with how we're setting things up here?)
I think we need to rebuild our HfIterableDataset based on the patterns we observe when testing out this training loop. Right now, it's not super intuitive to use, and it's also not clear what we actually need from it.
My guess is that we'll keep 75% of the current implementation, but I don't want to assume before we actually start testing.
ds = ds.map(gsm8k_to_messages)
ds = ds.shuffle()
I notice no split_dataset_by_node. Is this by design?
No, just missing it. Will add.
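For reference, a sketch of how the Hugging Face datasets helper could be wired in here (the dataset path, rank/world_size plumbing, and shuffle arguments are assumptions; the map/shuffle calls mirror the snippet above):

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# rank and world_size would come from however this actor is launched.
ds = load_dataset("openai/gsm8k", "main", split="train", streaming=True)
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
ds = ds.map(gsm8k_to_messages)
ds = ds.shuffle(buffer_size=1000, seed=42)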
apps/grpo/main.py
avg_reward = sum(group.reward for group in episode.groups) / len(
    episode.groups
)
wandb.log({"rollout_count": rollout_count, "avg_reward": avg_reward})
Will this eventually be its own actor?
I'm not sure it needs to be unless we're making so many requests from all over to log things and we want it going through a central source.
Otherwise, we can just treat the logger as a normal component.
Would wandb complain about having multiple "sessions" created from the same job? If we can minimize the messages being passed in Monarch, that generally seems like a good idea.
I think we want it to be its own actor so logs can be called from within actors too. It's a lot of unnecessary data passing otherwise, especially if people log artifacts.
Hmm, I'm not sure I understand that comment @pbontrager.
If the metrics logger is an actor, then all actors would need to pass messages over Monarch to the metrics actor, right? If actors can just write directly to wandb, that would minimize data passing?
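One possible middle ground, sketched below: each actor opens its own wandb run, and the runs are tied together with wandb's group argument so no metrics have to be shipped over Monarch (the project, group, and run names are made up; rank is assumed to be available in the actor):

import wandb

wandb.init(
    project="grpo-skeleton",        # illustrative project name
    group="grpo-run-001",           # groups the per-actor runs in the W&B UI
    name=f"rollout-worker-{rank}",  # one run per actor process
)
wandb.log({"avg_reward": avg_reward})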
self.lambda_ = lambda_  # GAE lambda parameter

@endpoint
async def __call__(self, groups: list[Group]) -> list[float]:
Is this based off of some specific reference implementation?
My brain, the old impl of GRPO in torchtune, and good ol devmate.
I'm not sure I trust that consensus..
We're probably not going to want a Group type, but we put the type work on the back burner for now
We should spin up a track fully dedicated to ensuring numerical correctness (which I think is the basis of @ebsmothers' comment), but imo the right timing for this is after we have a working prototype.
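For eventual numerical checks: the commonly cited GRPO advantage is a per-group normalization of rewards. A minimal reference sketch follows (not necessarily what this PR's __call__ computes, since it also carries a GAE lambda parameter):

import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rewards: tensor of shape (group_size,), one reward per sampled response.
    # Each response's advantage is its reward standardized within its own group.
    return (rewards - rewards.mean()) / (rewards.std() + eps)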
@endpoint
async def add(self, trajectory: Trajectory) -> None:
    self.buffer.append(trajectory)
async def add(self, episode) -> None:
So we are moving from Trajectories to Episodes in the replay buffer? Isn't this overly tailored to GRPO?
Everything in here is overly tailored to GRPO. I want to base our design decisions on real life; otherwise we can talk all day. This is definitely not the final state of our APIs.
Episode isn't GRPO-specific; it's basically another word for Trajectory.
this is awesome, great work @joecummings!
self.prompt = prompt
self.target = target
self.policy_version = policy_version
self.groups: list[Group] = []
Naming nit, but isn't what we're calling Group here actually an individual output, and what we're calling self.groups the actual Group?
Fair lol
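To make the naming suggestion concrete, a hypothetical restructuring (field names follow the ones in the snippet above; the Completion/Group split is illustrative only, not part of this PR):

from dataclasses import dataclass, field

@dataclass
class Completion:
    # One individual sampled response (what the current code calls a Group).
    text: str
    reward: float

@dataclass
class Episode:
    prompt: str
    target: str
    policy_version: int
    # The actual "group": all completions sampled for this prompt.
    group: list[Completion] = field(default_factory=list)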
ref_logprobs_list = []
advantages_list = []

for group in groups:
no action needed now, but imo it would be nifty for the episode structure itself to handle things like this
for episode in batch:
    tokenized = self.tokenizer(
        episode.as_response_group(),
        ...
    )
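Building on that, a guess at what such a helper could look like on the episode type (as_response_group is not defined in this PR; the concatenation below is an assumption about what it would return):

class Episode:
    ...
    def as_response_group(self) -> list[str]:
        # Hypothetical: pair the prompt with each sampled response so the
        # whole group can be tokenized in a single tokenizer call.
        return [self.prompt + completion.text for completion in self.group]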
This is really great! There will be a lot of refactor work going forward so I won't push for changes here.
@endpoint
async def setup(self):
    # Set up policy_worker
    self.available_devices = (
This should really come from the script and not the config. Ideally the script makes the master set of ranks and then passes them to the actors.
Summary:

What this PR DOES do:
- Core GRPO Implementation
- Service Infrastructure

What this PR will NOT cover:
- Missing Features: weight updates (the # await trainer.update_weights(policy) call is still commented out) @pradeepfn @pbontrager
- Incomplete Work: main.py @Jack-Khuu
- Technical Limitations

This is a working prototype of GRPO: the core RL loop is functional, but additional work is needed for production deployment, proper weight updates, and comprehensive logging.