
Conversation

@yuki-97
Contributor

@yuki-97 yuki-97 commented Feb 2, 2026

As title.

Summary by CodeRabbit

  • Refactor

    • Optimized loss function initialization workflow by consolidating setup operations.
  • Bug Fixes

    • Added runtime validation safeguard for tensor model parallel configurations to ensure batch size consistency.

@yuki-97 yuki-97 requested review from a team as code owners February 2, 2026 06:33
@yuki-97 yuki-97 added the CI:L1 Run doctests, unit tests, and functional tests label Feb 2, 2026
@yuki-97 yuki-97 requested review from terrykong and yfw February 2, 2026 06:34
@coderabbitai
Contributor

coderabbitai bot commented Feb 2, 2026

📝 Walkthrough

Loss function creation and validation in the GRPO algorithm were moved from the initialization phase to the setup phase for earlier preparation. Additionally, a runtime safeguard was added to the policy model to enforce batch size consistency when using tensor model parallelism with tp_size >= 4.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **GRPO Loss Function Setup**<br>`nemo_rl/algorithms/grpo.py` | Relocated the loss function instantiation and `force_on_policy_ratio` validation from `initialize_generation_with_policy` to the `setup` method; removed duplicate initialization logic and consolidated the loss-function checks earlier in the pipeline. |
| **Policy Model Parallelism Safeguard**<br>`nemo_rl/models/policy/lm_policy.py` | Added a runtime assertion in `Policy.__init__` enforcing that `train_micro_batch_size` equals `logprob_batch_size` when the tensor model parallel size is 4 or greater, unless the `NRL_IGNORE_TP_ACCURACY_CHECK` environment variable is set. The assertion message includes detailed remediation guidance. |
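
For orientation, a minimal sketch of what such a guard could look like. The helper name is made up for illustration, and `tp_size` plus the two config keys are assumed to be available inside `Policy.__init__`; the actual implementation and message wording in the PR may differ.

```python
import os

def _check_tp_batch_size_consistency(config: dict, tp_size: int) -> None:
    """Sketch of the guard: refuse mismatched batch sizes at tp_size >= 4 unless bypassed."""
    if tp_size >= 4 and not os.environ.get("NRL_IGNORE_TP_ACCURACY_CHECK"):
        assert config["train_micro_batch_size"] == config["logprob_batch_size"], (
            "With tensor model parallel size >= 4, train_micro_batch_size must equal "
            "logprob_batch_size to avoid a known accuracy issue; set "
            "NRL_IGNORE_TP_ACCURACY_CHECK=1 to bypass this check."
        )
```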

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~15 minutes

Suggested labels

CI:L1

Suggested reviewers

  • hemildesai
  • terrykong
🚥 Pre-merge checks | ✅ 4

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title mentions adding an assert for a tp4 batch variant accuracy issue, which aligns with the primary change in nemo_rl/models/policy/lm_policy.py, where a runtime safeguard assertion is added. However, the title only partially reflects the changeset: it doesn't mention the secondary change in grpo.py involving the loss function refactoring. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage; skipping the docstring coverage check. |
| Test Results For Major Changes | ✅ Passed | The PR contains minor changes: code reorganization in grpo.py and defensive validation in lm_policy.py. No algorithmic modifications, no feature additions, no numerical impact. Test documentation is not required for minor changes. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/grpo.py`:
- Around lines 303-313: The print for the force_on_policy_ratio check doesn't flush output like the other prints in setup(); update the print call that outputs "  ✓ force_on_policy_ratio enabled" to include flush=True so it behaves consistently with the other setup() output (locate the block that checks loss_config.get("force_on_policy_ratio"), sets os.environ["NRL_IGNORE_TP_ACCURACY_CHECK"], and currently calls print()).
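
A minimal sketch of this suggestion. The wrapping helper and surrounding condition are assumptions made to keep the snippet self-contained; the actual suggestion is only adding `flush=True` to the existing print call in `setup()`.

```python
import os

def _maybe_enable_on_policy_ratio(loss_config: dict) -> None:
    """Sketch of the block in setup(); the fix is just flush=True on the print."""
    if loss_config.get("force_on_policy_ratio"):
        os.environ["NRL_IGNORE_TP_ACCURACY_CHECK"] = "1"
        print("  ✓ force_on_policy_ratio enabled", flush=True)  # flush like the other setup() prints
```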

In `@nemo_rl/models/policy/lm_policy.py`:
- Around lines 138-151: Replace the runtime assert with an explicit conditional that always runs: parse os.environ.get("NRL_IGNORE_TP_ACCURACY_CHECK") by normalizing it to lowercase and treating only "1", "true", or "yes" as truthy; then, if the bypass is not set and tp_size >= 4, check whether config["train_micro_batch_size"] != config["logprob_batch_size"] and raise a RuntimeError (or ValueError) with the same multi-line message. Update the block around tp_size and the config checks in lm_policy.py (refer to tp_size, config["train_micro_batch_size"], config["logprob_batch_size"], and the NRL_IGNORE_TP_ACCURACY_CHECK env var) so the validation cannot be skipped by running Python with -O or by setting the env var to "0"/"false".
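
A minimal sketch of that suggested rework, reusing the names referenced above; the helper name is hypothetical, and the real change would carry over the existing multi-line remediation message rather than the placeholder text shown here.

```python
import os

def _validate_tp_batch_sizes(config: dict, tp_size: int) -> None:
    """Sketch of the suggestion: not stripped by `python -O`, opt-out only on explicit truthy values."""
    ignore_check = os.environ.get("NRL_IGNORE_TP_ACCURACY_CHECK", "").lower() in ("1", "true", "yes")
    if not ignore_check and tp_size >= 4:
        if config["train_micro_batch_size"] != config["logprob_batch_size"]:
            raise RuntimeError(
                "train_micro_batch_size must equal logprob_batch_size when tensor model "
                "parallel size >= 4; set NRL_IGNORE_TP_ACCURACY_CHECK=1 to bypass this check."
            )
```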

@yuki-97 yuki-97 force-pushed the yukih/assert-diff-bs branch from 42d0467 to e10241c Compare February 2, 2026 08:08
@yuki-97 yuki-97 requested a review from a team as a code owner February 2, 2026 08:08
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Feb 2, 2026
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 2, 2026
@yuki-97 yuki-97 requested a review from a team as a code owner February 2, 2026 11:33
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 2, 2026
Signed-off-by: Yuki Huang <[email protected]>
@yuki-97 yuki-97 force-pushed the yukih/assert-diff-bs branch from 3eb8376 to 29465d7 Compare February 2, 2026 11:35
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 2, 2026
```yaml
tokenizer:
  name: google/gemma-3-27b-it
train_micro_batch_size: 1
logprob_batch_size: 2
```
Contributor

how about making the edit here:
https://github.com/NVIDIA-NeMo/RL/blob/main/examples/configs/grpo_math_1B.yaml#L78

and do something like:

```yaml
logprob_batch_size: ${.train_micro_batch_size}
```


Labels

CI:L1 Run doctests, unit tests, and functional tests
documentation Improvements or additions to documentation


3 participants