Conversation

joecummings
Member

@meta-cla meta-cla bot added the CLA Signed label Sep 8, 2025
@joecummings joecummings changed the title from Working updates to On-policy GRPO Sep 8, 2025
@casteryh
Contributor

casteryh commented Sep 9, 2025

Just FYI, the link gives me a 404.

@joecummings
Member Author

Just FYI, the link gives me a 404.

https://wandb.ai/jcummings/grpo-training/workspace?nw=nwuserjcummings

Try now?

Contributor

@pbontrager pbontrager left a comment

This is very important to get in; really, thanks for this! I added a comment on how I think we should handle the policy update logic, though.

return
prompt, target = sample["request"], sample["target"]
version = 0 # await policy.get_current_version.choose()
responses = await policy.generate.choose(prompt)
Contributor

We'll throw away a lot of data this way for fully on-policy training.

Member Author

For short responses, yeah, definitely. If you look at the WandB logs (buffer_size/rollout), you can see that we build up a buffer of about 100 episodes and then evict the majority of them back and forth during weight updates.

When we start allowing much longer generations and using much bigger models, this won't be as big of an issue.
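For illustration, a minimal sketch of the kind of version-based eviction being described; Episode, ReplayBuffer, and policy_version are hypothetical names for this sketch, not the PR's actual implementation.

```python
from dataclasses import dataclass, field


@dataclass
class Episode:
    prompt: str
    response: str
    policy_version: int  # version of the policy that generated this episode


@dataclass
class ReplayBuffer:
    episodes: list[Episode] = field(default_factory=list)

    def add(self, episode: Episode) -> None:
        self.episodes.append(episode)

    def evict_stale(self, current_version: int) -> int:
        """Drop episodes generated under an older policy; return how many were evicted."""
        before = len(self.episodes)
        self.episodes = [e for e in self.episodes if e.policy_version == current_version]
        return before - len(self.episodes)
```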

logger.log("loss/training_step", loss_value, training_step)
# await trainer.update_weights(policy)
logger.log("loss/training_step", loss, training_step)
await trainer.push_weights.choose(policy_version)
Contributor

This should technically also be call, even though choose works since replicas=1.

Contributor

This should technically also be call, even though choose works since replicas=1.

As a side note, even if we do call(), what we're doing here is having all the trainer replicas train on the same batch, right?

The replicas are just for fault tolerance? In that case, if we want different trainers to train on different batches, the trainers themselves have to pull the batches, right?

An alternative is to split the batch into microbatches and call choose() on each microbatch. After the whole batch is done, we do an all_reduce (or another form of reduce) to average the weights.
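A rough sketch of that microbatch idea, assuming only that trainer.train_step.choose(mb) is awaitable as in the snippets above; the Python-side averaging at the end stands in for the proposed all_reduce over weights and is purely illustrative.

```python
import asyncio


async def train_on_microbatches(trainer, batch, num_microbatches: int):
    """Split a batch into microbatches and dispatch each one via choose()."""
    size = max(1, len(batch) // num_microbatches)
    microbatches = [batch[i : i + size] for i in range(0, len(batch), size)]

    # Each choose() routes one microbatch to some replica of the trainer service.
    losses = await asyncio.gather(
        *(trainer.train_step.choose(mb) for mb in microbatches)
    )

    # Stand-in for the reduce step (e.g. an all_reduce averaging weights across replicas).
    return sum(losses) / len(losses)
```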

await asyncio.sleep(0.1)
else:
training_result = await trainer.train_step.choose(batch)
loss = await trainer.train_step.choose(batch)
Contributor

This should also be a call

@casteryh
Contributor

casteryh commented Sep 9, 2025

Just FYI, the link gives me a 404.

https://wandb.ai/jcummings/grpo-training/workspace?nw=nwuserjcummings

Try now?

Saw the graphs 📈 nice!! it's working!

project="grpo-training",
)

store = await MultiProcessStore.create_store()
Member Author

@LucasLLC Is this still the recommended way of doing things?

Contributor

Seems like we are using ts.initialize() now and there is a global singleton torchstore. But I will let @LucasLLC weigh in.

f"Starting model update from torchstore with key: {self.state_dict_key}{DELIM}{version}"
)

key = f"{self.state_dict_key}{DELIM}{version}"
Contributor

Let's make this a function since the caller also uses it

Contributor

Definitely make this a function, and we should do f"{version}{DELIM}{self.state_dict_key}" instead, which is the correct hierarchy to use given the new keys(prefix) in the torchstore API.
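A minimal sketch of the suggested helper; the function name and the DELIM value shown here are placeholders, while the f-string itself is the one proposed above.

```python
DELIM = "/"  # assumption: stands in for whatever delimiter constant the module already defines


def model_update_key(version: int, state_dict_key: str, delim: str = DELIM) -> str:
    """Build the torchstore key with the version as the top-level prefix, per the suggestion above."""
    return f"{version}{delim}{state_dict_key}"
```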

@joecummings joecummings changed the title from On-policy GRPO to Off-policy GRPO Sep 11, 2025
@joecummings joecummings changed the title from Off-policy GRPO to Off-by-1 GRPO Sep 11, 2025
beta: float = 0.1
epsilon: float = 0.1
device: torch.device | None = None
store: MultiProcessStore | None = None
Contributor

This isn't the currently recommended way to use the store. You should just access it as a singleton inside the trainer.


self.logger.info(f"Trainer model initialized on {self.device}")

def _qwen3_hf_to_vllm(self, saved_sd):
Contributor

Can you put this in the trainer.py file? We'll need to reuse this with titan, and this will merge nicely with Pradeep's PR.

await asyncio.sleep(0.1)
else:
training_result = await trainer.train_step.choose(batch)
loss = sum(await trainer.train_step.call(batch))
Contributor

This shouldn't be returning a list to sum over, right? There should be 1 replica for this service?

Contributor

Service call currently returns a list regardless of num_replicas

Contributor

I was wrong earlier; this should be a choose call. I think only policy.update_weights should be a call.
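To summarize the choose/call distinction discussed in this thread, a hedged sketch assuming the semantics described above (choose routes to a single replica; call fans out to every replica and returns a list even when replicas=1):

```python
async def train_once(trainer, batch):
    """Illustrative comparison of the two endpoint invocation styles discussed here."""
    # choose(): route the batch to one replica and get a single result back.
    loss = await trainer.train_step.choose(batch)

    # call(): fan out to every replica; the result is a list even when replicas=1,
    # which is why the PR sums over it before logging.
    losses = await trainer.train_step.call(batch)
    mean_loss = sum(losses) / len(losses)
    return loss, mean_loss
```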

asyncio.run(main(cfg))
if __name__ == "__main__":

@parse
Contributor

Can you explain this change?


@endpoint
async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]:
async def generate(self, prompt: str, priority: int = 0) -> RequestOutput:
Contributor

Why was this necessary? This pattern is never great for readability

start = time.time()
await self._load_tensor_parallel_state_dict(current_state_dict, version)
logger.debug("Successfully updated model weights from torchstore")
self.logger.debug(
Contributor

Better to log this to WandB than spam the terminal.
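A hypothetical sketch of what that could look like, reusing the lines from the diff above and the metric-logger pattern logger.log(name, value, step) used for loss elsewhere in this PR; the function, metric name, and step choice are made up for illustration.

```python
import time


async def update_and_report(worker, state_dict, version, metrics):
    """Time the state-dict load and report the duration as a metric instead of a debug line."""
    start = time.time()
    await worker._load_tensor_parallel_state_dict(state_dict, version)
    # Same logging pattern as the loss metric: name, value, step.
    metrics.log("weight_update/duration_s", time.time() - start, version)
```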

Contributor

Why do we need to use self.logger instead of just logger?

Member Author

With the creation of ForgeActor, each actor has its own logger and should use that (because it logs information about which actor is logging) instead of a global logger.

Contributor

Ah, the ForgeActor logger was implemented in a way where it's supposed to just work with logger (i.e., no need for self.logger), but please let me know if that doesn't work.
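For context, a hypothetical illustration (not the actual ForgeActor code) of why a per-actor logger identifies the emitting actor; the class and logger names are invented for this sketch.

```python
import logging


class ActorWithLogger:
    """Hypothetical stand-in for ForgeActor: each actor owns a logger named after itself."""

    def __init__(self) -> None:
        # Every record emitted through self.logger carries the actor's class name.
        self.logger = logging.getLogger(f"forge.{type(self).__name__}")


class Trainer(ActorWithLogger):
    def train_step(self) -> None:
        self.logger.info("running a train step")  # record is tagged "forge.Trainer"
```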

Contributor

@allenwang28 allenwang28 left a comment

Generally looks good; if you could just do me a favor and add some of these issue tags!

pass
async def push_weights(self, version: int):
"""Update policy model weights with trainer's current weights."""
start_time = time.time()
Contributor

Suggested change:
- start_time = time.time()
+ # TODO - issues/148 followup
+ start_time = time.time()

Just for my future reference, tagging some pieces for observability


@endpoint
async def compute(self, group: Group) -> list[float]:
# TODO: add batch processing
Contributor

Suggested change:
- # TODO: add batch processing
+ # TODO: issues/120 add batch processing

model = self.worker.model_runner.model
current_state_dict = model.state_dict()

start = time.time()
Contributor

Suggested change:
- start = time.time()
+ # TODO - issues/148
+ start = time.time()

Contributor

@pbontrager pbontrager left a comment

Thanks for going through so much detail. Please remove the "_functions" but good to go.

@joecummings joecummings merged commit a6ca591 into meta-pytorch:main Sep 12, 2025
5 checks passed
self.engine.checkpointer.close()


def _qwen3_hf_to_vllm(
Contributor

Maybe for the next PR: thoughts on having this on the generator/policy side? Ideally the trainer should be agnostic to how the policy packs/transforms certain tensors.

"""Compute math correctness reward."""
# Parse expected
expected_answer = self._to_float(target)
target_number = self._to_float(target)
Contributor

Curious: do we have a plan for math verifier libraries that provide certain things (ExactMatch, partial match, etc.) out of the box?
