-
Notifications
You must be signed in to change notification settings - Fork 234
chore: add assert for tp4 batch variant accuracy issue #1861
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
📝 WalkthroughWalkthroughLoss function creation and validation in the GRPO algorithm were moved from the initialization phase to the setup phase for earlier preparation. Additionally, a runtime safeguard was added to the policy model to enforce batch size consistency when using tensor model parallelism with tp_size >= 4. Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~15 minutes Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/grpo.py`:
- Around line 303-313: Print for the force_on_policy_ratio check in grpo.py
doesn't flush output like other prints in setup(); update the print call that
outputs " ✓ force_on_policy_ratio enabled" to include flush=True so it behaves
consistently with other setup() output (locate the block that checks
loss_config.get("force_on_policy_ratio"), which sets
os.environ["NRL_IGNORE_TP_ACCURACY_CHECK"] and currently calls print()).
In `@nemo_rl/models/policy/lm_policy.py`:
- Around line 138-151: Replace the runtime assert with an explicit conditional
that always runs: parse os.environ.get("NRL_IGNORE_TP_ACCURACY_CHECK") by
normalizing to lowercase and treating only "1", "true", or "yes" as truthy, then
if the bypass is not set and tp_size >= 4 check if
config["train_micro_batch_size"] != config["logprob_batch_size"] and raise a
RuntimeError (or ValueError) with the same multi-line message; update the block
around tp_size and the config checks in lm_policy.py (refer to tp_size,
config["train_micro_batch_size"], config["logprob_batch_size"], and the
NRL_IGNORE_TP_ACCURACY_CHECK env var) so the validation cannot be skipped by
Python -O or by setting the env var to "0"/"false".
Signed-off-by: Yuki Huang <[email protected]>
Signed-off-by: Yuki Huang <[email protected]>
Signed-off-by: Yuki Huang <[email protected]>
42d0467 to
e10241c
Compare
Signed-off-by: Yuki Huang <[email protected]>
3eb8376 to
29465d7
Compare
| tokenizer: | ||
| name: google/gemma-3-27b-it | ||
| train_micro_batch_size: 1 | ||
| logprob_batch_size: 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about making the edit here:
https://github.com/NVIDIA-NeMo/RL/blob/main/examples/configs/grpo_math_1B.yaml#L78
and do something like:
logprob_batch_size: ${.train_micro_batch_size}
As title.
Summary by CodeRabbit
Refactor
Bug Fixes