Off-by-1 GRPO #140
Conversation
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Just fyi the link gives me 404
https://wandb.ai/jcummings/grpo-training/workspace?nw=nwuserjcummings Try now?
This is very important to get in, really thanks for this! I added a comment on how I think we should handle the policy update logic though.
return
prompt, target = sample["request"], sample["target"]
version = 0  # await policy.get_current_version.choose()
responses = await policy.generate.choose(prompt)
We'll throw away a lot of data this way for fully on-policy training.
For short responses yeah definitely, if you look at the WandB logs (buffer_size/rollout), you can see that we build up a buffer of about 100 episodes and then evict the majority of them back and forth during weight updates.
When we start allowing much longer generations and our models are much bigger, this won't be as big of an issue.
apps/grpo/main.py (Outdated)
- logger.log("loss/training_step", loss_value, training_step)
- # await trainer.update_weights(policy)
+ logger.log("loss/training_step", loss, training_step)
+ await trainer.push_weights.choose(policy_version)
This should technically also be `call`, even though `choose` works since replicas=1.
> This should technically also be `call`, even though `choose` works since replicas=1.
As a side note, even if we do `call()`, what we are doing here is all the trainers training on the same batch, right? The replicas are just for fault tolerance? In that case, if we want different trainers to train on different batches, the trainers themselves would have to pull the batches, right?
An alternative is to split the batch into microbatches and call `choose()` on each microbatch. After the whole batch is done, we then do an `all_reduce` (or another form of reduce) to average the weights.
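A rough sketch of that alternative in Python. Everything below is hypothetical: `split_into_microbatches` is defined inline for illustration, and the `get_weights` / `load_weights` endpoints are assumed stand-ins rather than the actual forge trainer API.

    # Hypothetical sketch only: split_into_microbatches is defined here for
    # illustration, and get_weights / load_weights are assumed endpoint names,
    # not the real trainer API.
    import asyncio

    def split_into_microbatches(batch, n):
        """Slice a list-like batch into n roughly equal microbatches."""
        size = (len(batch) + n - 1) // n
        return [batch[i : i + size] for i in range(0, len(batch), size)]

    async def train_batch_across_replicas(trainer, batch, num_microbatches):
        # Dispatch each microbatch to some replica via choose().
        microbatches = split_into_microbatches(batch, num_microbatches)
        await asyncio.gather(
            *(trainer.train_step.choose(mb) for mb in microbatches)
        )
        # Once the whole batch is done, average the replica weights
        # (an all_reduce-style mean) so every trainer ends up in sync.
        replica_weights = await trainer.get_weights.call()
        averaged = {
            name: sum(w[name] for w in replica_weights) / len(replica_weights)
            for name in replica_weights[0]
        }
        await trainer.load_weights.call(averaged)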
  await asyncio.sleep(0.1)
  else:
- training_result = await trainer.train_step.choose(batch)
+ loss = await trainer.train_step.choose(batch)
This should also be a call
Saw the graphs 📈 nice!! it's working!
apps/grpo/main.py (Outdated)
project="grpo-training",
)

store = await MultiProcessStore.create_store()
@LucasLLC Is this still the recommended way of doing things?
Seems like we are using `ts.initialize()` now and there is a global singleton torchstore. But I will let @LucasLLC weigh in.
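For reference, a minimal sketch of the singleton-style setup being referred to. The exact signatures (and the put/get calls) are assumptions about the torchstore API, not confirmed usage.

    # Sketch only: assumes torchstore exposes a process-wide singleton after
    # ts.initialize(); the put/get names below are illustrative assumptions.
    import torchstore as ts

    async def setup_store():
        await ts.initialize()  # replaces MultiProcessStore.create_store()
        # Actors would then talk to the same global store, e.g. (hypothetically):
        #   await ts.put(key, tensor)
        #   tensor = await ts.get(key)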
f"Starting model update from torchstore with key: {self.state_dict_key}{DELIM}{version}"
)

key = f"{self.state_dict_key}{DELIM}{version}"
Let's make this a function since the caller also uses it
Definitely make this a function, and we should do `f"{version}{DELIM}{self.state_dict_key}"` instead, which is the correct hierarchy to use given the new `keys(prefix)` in the torchstore API.
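A minimal sketch of such a helper, assuming the `DELIM` and `state_dict_key` names from the diff above (the concrete delimiter value is a placeholder):

    # Sketch only: DELIM's value is a placeholder; the names mirror the diff above.
    DELIM = "/"

    def weights_key(state_dict_key: str, version: int) -> str:
        """Build the torchstore key with the version first, so a keys(prefix)
        lookup can enumerate everything written for a given version."""
        return f"{version}{DELIM}{state_dict_key}"

Both the pusher and the loader would then call the same helper instead of formatting the key inline.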
apps/grpo/main.py (Outdated)
beta: float = 0.1
epsilon: float = 0.1
device: torch.device | None = None
store: MultiProcessStore | None = None
This isn't the currently recommended way to use the store. You should just access it as a singleton inside of the trainer.
apps/grpo/main.py (Outdated)
self.logger.info(f"Trainer model initialized on {self.device}")

def _qwen3_hf_to_vllm(self, saved_sd):
Can you put this in the trainer.py file? We'll need to reuse this with titan and this will merge nicely with Pradeep's PR
apps/grpo/main.py (Outdated)
  await asyncio.sleep(0.1)
  else:
- training_result = await trainer.train_step.choose(batch)
+ loss = sum(await trainer.train_step.call(batch))
This shouldn't be returning a list to sum over, right? There should be 1 replica for this service?
Service call currently returns a list regardless of num_replicas
I was wrong earlier, this should be a `choose` call. I think only `policy.update_weights` should be a `call`.
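For context, a short illustration of the difference being discussed, assuming `trainer.train_step` is a service endpoint (the return values are made up):

    # Illustrative sketch only: assumes trainer.train_step is a service endpoint.
    async def train_once(trainer, batch):
        # .call() fans out to every replica and returns a list of per-replica
        # results, even when the service runs a single replica.
        losses = await trainer.train_step.call(batch)   # e.g. [0.42] with replicas=1
        loss_from_call = losses[0]

        # .choose() routes the request to one replica and returns a single result.
        loss_from_choose = await trainer.train_step.choose(batch)  # e.g. 0.42
        return loss_from_call, loss_from_choose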
asyncio.run(main(cfg))
if __name__ == "__main__":

@parse
Can you explain this change?
  @endpoint
- async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]:
+ async def generate(self, prompt: str, priority: int = 0) -> RequestOutput:
Why was this necessary? This pattern is never great for readability
  start = time.time()
  await self._load_tensor_parallel_state_dict(current_state_dict, version)
- logger.debug("Successfully updated model weights from torchstore")
+ self.logger.debug(
Better to log this to wandb than spam the terminal.
why do we need to do `self.logger` instead of just `logger`?
With the creation of ForgeActor, each actor has its own logger and should use that (b/c it logs information about which actor is logging) instead of a global logger.
ah, the ForgeActor logger was implemented in a way where it's supposed to just work with `logger` (i.e. no need for `self.logger`), but please let me know if that doesn't work
generally looks good, if you could just do me a favor and add some of these issue tags!
apps/grpo/main.py (Outdated)
pass
async def push_weights(self, version: int):
    """Update policy model weights with trainer's current weights."""
    start_time = time.time()
Suggested change:
- start_time = time.time()
+ # TODO - issues/148 followup
+ start_time = time.time()
Just for my future reference, tagging some pieces for observability
@endpoint
async def compute(self, group: Group) -> list[float]:
    # TODO: add batch processing
Suggested change:
- # TODO: add batch processing
+ # TODO: issues/120 add batch processing
model = self.worker.model_runner.model
current_state_dict = model.state_dict()

start = time.time()
Suggested change:
- start = time.time()
+ # TODO - issues/148
+ start = time.time()
Thanks for going through so much detail. Please remove the "_functions" but good to go.
self.engine.checkpointer.close()


def _qwen3_hf_to_vllm(
Maybe for the next PR: thoughts on having this on the generator/policy side? Ideally the trainer should be agnostic to how the policy packs/transforms certain tensors.
"""Compute math correctness reward.""" | ||
# Parse expected | ||
expected_answer = self._to_float(target) | ||
target_number = self._to_float(target) |
Curious: do we have a plan for math verifier libraries that provide certain things (ExactMatch, partial match, etc.) out of the box?