Trainer Loss #146

Conversation
Is the intention to integrate this into the GRPO app, or leave that to me?
Also, how is the loss passed in?
LGTM, accepting to unblock.
We can follow up with a GRPO PR.
    def train_step(
        self, inputs: list[dict[Tensor]], targets: list[dict[Tensor]]
    ) -> None:
Update the return type.
        return tensor
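A minimal sketch of what the updated signature could look like, assuming the step is meant to return the loss tensor; the Tensor annotation, the str key type, and the loss_fn attribute are assumptions for illustration, not the interface from this PR:

from torch import Tensor


class Trainer:
    """Hypothetical minimal trainer; the return annotation is the only point here."""

    def __init__(self, loss_fn):
        self.loss_fn = loss_fn  # user-supplied loss callable (assumed interface)

    def train_step(
        self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
    ) -> Tensor:
        loss = self.loss_fn(inputs, targets)
        # backward pass and optimizer step elided; only the signature matters here
        return loss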
def collate(batches: list[list[Episode]]):
Return type?
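For the collate annotation, a sketch of one way it could read; the Episode fields and the (inputs, targets) return shape are assumptions chosen to line up with train_step above, not the actual definitions in this PR:

from dataclasses import dataclass

import torch
from torch import Tensor


@dataclass
class Episode:
    # Hypothetical episode payload; the real Episode in the PR may differ.
    tokens: Tensor
    rewards: Tensor


def collate(
    batches: list[list[Episode]],
) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]:
    """Stack each batch of episodes into the (inputs, targets) lists train_step expects."""
    inputs, targets = [], []
    for batch in batches:
        inputs.append({"tokens": torch.stack([ep.tokens for ep in batch])})
        targets.append({"rewards": torch.stack([ep.rewards for ep in batch])})
    return inputs, targets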
Delete the commented-out code, please.
print("Collecting Data...")
g = torch.manual_seed(0)
global_batch_size = cfg.replay_buffer.batch_size * cfg.replay_buffer.dp_size
No need to change right now, but wouldn't the global batch size technically be a trainer config rather than a replay buffer config? Is it possible to tell OmegaConf something like "make this value the same as the one defined elsewhere"?
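OmegaConf does support value interpolation for exactly that ("make this value the same as the one defined elsewhere"); a small sketch with made-up config keys, not the layout used in this repo:

from omegaconf import OmegaConf

# Hypothetical layout: replay_buffer.batch_size mirrors a value defined under trainer.
cfg = OmegaConf.create(
    {
        "trainer": {"global_batch_size": 64},
        "replay_buffer": {"batch_size": "${trainer.global_batch_size}", "dp_size": 8},
    }
)
print(cfg.replay_buffer.batch_size)  # resolves to 64 via interpolation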
Not a trainer config anymore, since replay_buffer is its own service and the trainer doesn't handle data or loading anymore.
Merging to unblock; please follow up on the nits.
The goal of this PR is to enable titan training while keeping the data definition and the loss easily accessible to the user. The loss, data type, and collation are left to the user so they can co-design all three for their specific data and are less likely to have to touch the trainer.
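As a rough illustration of that co-design idea (everything below, including the toy loss and the way the pieces are handed to the trainer, is hypothetical and reuses the sketches above rather than the actual interface added in this PR): the user owns the episode type, the collation, and the loss, and the trainer only ever consumes what they produce.

import torch
from torch import Tensor


def toy_loss(inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]) -> Tensor:
    # Toy stand-in for a user-defined objective (e.g. a GRPO loss would go here).
    per_batch = [
        torch.nn.functional.mse_loss(i["tokens"].float(), t["rewards"].float())
        for i, t in zip(inputs, targets)
    ]
    return torch.stack(per_batch).mean()


trainer = Trainer(loss_fn=toy_loss)  # Trainer, collate, and Episode from the sketches above
inputs, targets = collate([[Episode(tokens=torch.ones(4), rewards=torch.zeros(4))]])
loss = trainer.train_step(inputs, targets)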