Conversation

@pbontrager (Contributor) commented Sep 10, 2025

  • Updates the trainer for the GRPO loss
  • Updates the replay buffer to do batching
  • Updates the RL app test to run one training step (this app is meant for testing)

The goal of this PR is to enable titan training while keeping the data definition and the loss easily accessible to the user. The loss, the data type, and the collation are left to the user so they can co-design all three for their specific data and are less likely to have to touch the trainer.
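For concreteness, here is a minimal sketch of that pattern; all names (Episode, collate, simple_grpo_loss) are hypothetical stand-ins, not the actual API:

    # Hypothetical sketch of the user-owned pieces; illustrative only.
    from dataclasses import dataclass

    import torch
    from torch import Tensor


    @dataclass
    class Episode:
        """User-defined data type for one sampled trajectory."""
        tokens: Tensor        # [seq_len] sampled token ids
        ref_logprobs: Tensor  # [seq_len] reference-policy log-probs
        advantages: Tensor    # [seq_len] per-token advantages


    def collate(
        batches: list[list[Episode]],
    ) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]:
        """User-defined collation: one (inputs, targets) dict pair per inner list."""
        inputs, targets = [], []
        for batch in batches:
            tokens = torch.stack([e.tokens for e in batch])
            inputs.append({"tokens": tokens})
            targets.append(
                {
                    "tokens": tokens,
                    "ref_logprobs": torch.stack([e.ref_logprobs for e in batch]),
                    "advantages": torch.stack([e.advantages for e in batch]),
                }
            )
        return inputs, targets


    def simple_grpo_loss(logits: Tensor, targets: dict[str, Tensor]) -> Tensor:
        """User-defined loss, co-designed with Episode and collate above."""
        logprobs = torch.log_softmax(logits, dim=-1)
        logprobs = logprobs.gather(-1, targets["tokens"].unsqueeze(-1)).squeeze(-1)
        ratio = torch.exp(logprobs - targets["ref_logprobs"])
        # Simplified policy-gradient term; full GRPO adds clipping and a KL penalty.
        return -(ratio * targets["advantages"]).mean()

Since all three pieces live in user code, the trainer only has to move batches to the device and call the loss.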

@meta-cla bot added the CLA Signed label Sep 10, 2025
@pbontrager marked this pull request as ready for review September 12, 2025 23:07
@joecummings (Member) left a comment

Is the intention to integrate this into the GRPO app or leave that to me?

Also - how is the loss passed in?

@Jack-Khuu (Contributor) left a comment

LGTM, accepting to unblock.

We can follow up with a GRPO PR


    def train_step(
        self, inputs: list[dict[Tensor]], targets: list[dict[Tensor]]
    ) -> None:

Contributor:

Update return type:

    return tensor
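As an aside, dict in the type hints takes both a key and a value type, so a plausible version of the fixed signature (a guess at the intent, not the actual diff) would be:

    def train_step(
        self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
    ) -> Tensor:
        ...
        return tensor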


    def collate(batches: list[list[Episode]]):

Contributor:

return type?
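Judging from the train_step signature above, the annotation would plausibly be (again a guess, not confirmed by the diff):

    def collate(
        batches: list[list[Episode]],
    ) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]:
        ...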

@joecummings (Member) left a comment

Delete the commented-out code, please.


print("Collecting Data...")
g = torch.manual_seed(0)
global_batch_size = cfg.replay_buffer.batch_size * cfg.replay_buffer.dp_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to change right now, but wouldn't the global batch size technically be a trainer config and not a replay buffer config? Is it possible to tell omegaconf like "make this value the same as the one defined elsewhere"?
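For reference, OmegaConf does support this through variable interpolation; a minimal sketch, with hypothetical key names:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "trainer": {"global_batch_size": 64},
            "replay_buffer": {"batch_size": "${trainer.global_batch_size}"},
        }
    )
    assert cfg.replay_buffer.batch_size == 64  # interpolation resolves on access

In a YAML config the same reference is written as batch_size: ${trainer.global_batch_size}.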

@pbontrager (Contributor, Author) commented Sep 15, 2025

Not a trainer config anymore, since replay_buffer is its own service and the trainer doesn't handle data or loading anymore.

@LucasLLC (Contributor) left a comment

Merging to unblock, please follow up on nits

@LucasLLC merged commit cfd9677 into main Sep 15, 2025
5 checks passed