Rithesh/toy app #175
Conversation
Overall, this looks great and makes huge strides toward correctness in Forge. Just a couple of comments.
src/forge/types.py
```python
Scalar = Union[int, float]
```
I wouldn't say we're confident that these Episode and Group abstractions are the best ones yet - I'd be more comfortable if you just copy-pasta'd them into the sumdigits.py file in order to use them for now.
@vidhyav is rolling out the abstractions soon. Just centralizing this so that he has one place to fix.
Let me know if you still want me to copy-paste them.
I would still prefer a copy paste if that's alright? Sorry for being a stickler :)
Cool, copy-pasted the code and added a TODO.
apps/toy_rl/sumdigits.py
```python
mlogger.log("loss/training_step", loss, training_step)
print(f"loss/training_step: {loss} at {training_step}")
if training_step % 5 == 0:
    await trainer.push_weights.call(policy_version)
```
Weight sync is off by 5?
Yes, because this is a toy app and the weight sync takes a long time. :)
Ideally I'd like us to have accumulate-and-apply-gradients abstractions, so that we can accumulate the gradients and only apply them after every N batches (5 in this case).
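Just to sketch what I mean (this is not Forge's API - the model, optimizer, and loss below are toy placeholders for standard PyTorch gradient accumulation):

```python
import torch

# Toy stand-ins just to make the sketch runnable; not the actual sumdigits setup
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
dataloader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(20)]

ACCUM_STEPS = 5  # apply the accumulated gradients every N batches (5 here)

optimizer.zero_grad()
for step, (x, y) in enumerate(dataloader, start=1):
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / ACCUM_STEPS).backward()      # accumulate scaled gradients
    if step % ACCUM_STEPS == 0:
        optimizer.step()                 # apply the accumulated gradients
        optimizer.zero_grad()
```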
Makes sense - just curious, how much faster does it converge when the weight sync is only off by 1 via the replay buffer?
Here is the off-by-1 run: https://meta.wandb.io/torchforge/sumdigits-training/runs/wblr9xh7?nw=nwuserrithesh
Updated to be on-policy. We can figure out what's best later when we set up the CI.
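(For reference, the change amounts to dropping the `% 5` guard around the weight push - a sketch only, using the same calls as the snippet above:)

```python
# Push weights after every training step instead of every 5
mlogger.log("loss/training_step", loss, training_step)
await trainer.push_weights.call(policy_version)
```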
LGTM!
@Ritesh1905, based on your experience, is this a regression? https://meta.wandb.io/jiyue/sumdigits-training/runs/ah2js4mb?nw=nwuserjiyue
A simple toy-app RL loop that (almost) converges in less than 5 minutes. This uses a much simpler REINFORCE loss; I could not get the reward mean to converge with the GRPO loss. Sending this PR to get early feedback, and once it makes sense I will figure out how to make it work with the GRPO loss.
https://meta.wandb.io/rithesh/sumdigits-training/runs/kmj952x7?nw=nwuserrithesh