
feat: force on-policy ratio to 1#1529

Merged
terrykong merged 13 commits into main from yifu/force_on_policy
Dec 4, 2025
Conversation

Contributor

@yfw yfw commented Nov 17, 2025

What does this PR do ?

Adds a flag for "forcing" the on-policy ratio to 1 in the fully on-policy case (num_prompts_per_step * num_generations_per_prompt == train_global_batch_size). Original PR created by @HeyyyyyyG.

Usage

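A hedged usage sketch: per the example configs touched in this PR, the flag lives under loss_fn; the surrounding layout shown here is illustrative and the exact nesting may differ.

```yaml
loss_fn:
  # Force the PPO ratio to exactly 1.0. Only valid in the fully on-policy case:
  # train_global_batch_size == num_prompts_per_step * num_generations_per_prompt
  force_on_policy_ratio: true
```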

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.


Summary by CodeRabbit

  • New Features

    • Added force_on_policy_ratio configuration option to enforce on-policy PPO behavior.
    • Introduced probability ratio min/max metrics to track policy optimization dynamics.
  • Chores

    • Updated configuration files and tests to support the new on-policy ratio control option.


yfw and others added 9 commits November 12, 2025 18:34
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Prevents incorrect dp size in parallel_state during initial import.

Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Add packed_seq_params change to get_topk_logits too

Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Jiaqi Zeng <jiaqiz@nvidia.com>
@yfw yfw changed the title force on-policy ratio to 1 feat: force on-policy ratio to 1 Nov 17, 2025
Base automatically changed from yifu/nano-v2-main to main November 18, 2025 05:57
@yfw yfw added the CI:L1 Run doctests, unit tests, and functional tests label Dec 2, 2025
@yfw yfw added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Dec 2, 2025
@yfw yfw marked this pull request as ready for review December 2, 2025 07:33
@yfw yfw requested review from a team as code owners December 2, 2025 07:33
@yfw yfw added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Dec 2, 2025
@coderabbitai
Contributor

coderabbitai bot commented Dec 2, 2025

📝 Walkthrough

Walkthrough

A new configuration option force_on_policy_ratio (default false) is added across configuration files and implemented in core loss and training logic. When enabled, the PPO ratio is forced to 1.0 with the prerequisite that train_global_batch_size equals num_prompts_per_step * num_generations_per_prompt. Implementation includes initialization validation, min/max metrics tracking for probability ratios, and unit tests.

Changes

  • Configuration files (examples/configs/grpo_math_1B.yaml, examples/configs/vlm_grpo_3B.yaml, examples/configs/vlm_grpo_3B_megatron.yaml, examples/penguin/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml): Added a force_on_policy_ratio: false field under loss_fn, with a comment noting the prerequisite constraint on batch sizes.
  • Core loss function logic (nemo_rl/algorithms/loss_functions.py): Introduced a force_on_policy_ratio toggle in ClippedPGLoss to force the PPO ratio to 1.0 when enabled. Added min/max tracking metrics for probability ratios (probs_ratio_min, probs_ratio_max, probs_ratio_clamped_min, probs_ratio_clamped_max) with infinite-value filtering. Extended SequencePackingLossWrapper metric aggregation logic. Added a math import for safe infinity handling.
  • GRPO training logic (nemo_rl/algorithms/grpo.py): Added validation after loss-function initialization to ensure force_on_policy_ratio is only enabled when train_global_batch_size equals num_prompts_per_step * num_generations_per_prompt. Extended metrics aggregation in both grpo_train and async_grpo_train to compute min/max for ratio-related keys with infinite-value filtering.
  • Unit tests (tests/unit/algorithms/test_grpo.py, tests/unit/algorithms/test_loss_functions.py, tests/unit/algorithms/test_sequence_packing_gradients.py, tests/unit/models/policy/test_dtensor_worker.py, tests/unit/models/policy/test_megatron_worker.py): Added force_on_policy_ratio: False to the ClippedPGLossFn configuration in test fixtures. Added a new unit test, test_clipped_pg_loss_force_on_policy_ratio, validating ratio forcing and metrics propagation.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • nemo_rl/algorithms/loss_functions.py: Contains the core logic for force_on_policy_ratio enforcement and new metrics tracking; requires careful review of the ratio computation path, infinite-value handling, and metric aggregation logic.
  • nemo_rl/algorithms/grpo.py: Validation logic and metrics aggregation changes across two training functions; verify constraint checking is properly applied in both code paths.
  • tests/unit/algorithms/test_loss_functions.py: New test validates forced ratio behavior; verify synthetic data setup and expected loss computation are correct.

Suggested reviewers

  • parthchadha
  • ashors1

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 75.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Test Results For Major Changes ⚠️ Warning: The PR introduces a significant GRPO algorithm feature but lacks documented test results and validation evidence in the description. Resolution: add a test-results section documenting unit-test passes, new test outcomes, regression testing with the flag disabled, and performance metrics confirming the feature works without degradation.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: The title "feat: force on-policy ratio to 1" directly summarizes the main change across all modified files: adding a force_on_policy_ratio configuration option to force the PPO ratio to 1.0.
  • Description check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/algorithms/loss_functions.py (1)

36-61: Fix duplicated force_on_policy_ratio key in ClippedPGLossConfig

ClippedPGLossConfig currently declares force_on_policy_ratio twice: once as a required bool and once as NotRequired[bool] with the explanatory comment. Type checkers will flag this, and the intent (optional flag with documented semantics) matches the second declaration.

Recommend keeping a single optional entry:

 class ClippedPGLossConfig(TypedDict):
@@
-    token_level_loss: bool
-    force_on_policy_ratio: bool
+    token_level_loss: bool
@@
-    disable_ppo_ratio: NotRequired[bool]
-    # If True, force the ratio to 1.0 for truly on-policy behavior,
+    disable_ppo_ratio: NotRequired[bool]
+    # If True, force the ratio to 1.0 for truly on-policy behavior,
@@
-    force_on_policy_ratio: NotRequired[bool]
+    force_on_policy_ratio: NotRequired[bool]

That keeps the docs and avoids inconsistent typing.

🧹 Nitpick comments (2)
nemo_rl/algorithms/grpo.py (1)

585-595: On-policy ratio validation looks correct but is quite strict

The assertion enforces the documented requirement that force_on_policy_ratio only be used when train_global_batch_size == num_prompts_per_step * num_generations_per_prompt, which is good for catching misconfigurations early. Note that this will also forbid using the flag with async GRPO even if users think the config is “close enough,” since async inherently reuses trajectories; if that’s intentional, consider adding a brief clarifying note in docs or config comments so users don’t try to combine the two.
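The check under discussion can be sketched roughly as follows. This is a simplified illustration, not the NeMo RL code: a flat dict stands in for the real nested YAML config, and the helper name is hypothetical; key names follow the PR description.

```python
def check_force_on_policy(cfg: dict) -> None:
    """Reject configs that enable force_on_policy_ratio in an off-policy setup.

    The real validation runs after loss-function initialization in
    nemo_rl/algorithms/grpo.py; this sketch only mirrors the constraint.
    """
    if cfg.get("force_on_policy_ratio", False):
        expected = cfg["num_prompts_per_step"] * cfg["num_generations_per_prompt"]
        assert cfg["train_global_batch_size"] == expected, (
            "force_on_policy_ratio=True requires train_global_batch_size == "
            "num_prompts_per_step * num_generations_per_prompt"
        )
```

With the grpo_math_1B numbers (512 = 32 * 16) this passes; any mismatch raises immediately, catching the misconfiguration before training starts.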

nemo_rl/algorithms/loss_functions.py (1)

117-155: force_on_policy_ratio ctor wiring is correct; consider guarding incompatible modes

The ctor wiring of self.force_on_policy_ratio = cfg.get("force_on_policy_ratio", False) is consistent with existing flags and lets older configs omit the key. Behavior‑wise it works, but note that you now have three mutually interacting switches (disable_ppo_ratio, sequence_level_importance_ratios, force_on_policy_ratio). It may be worth asserting that obviously incompatible combinations (e.g., disable_ppo_ratio=True and force_on_policy_ratio=True) are rejected early rather than producing confusing metrics.

📥 Commits

Reviewing files that changed from the base of the PR and between cff17f8 and 50a72d2.

🔇 Additional comments (14)
examples/configs/vlm_grpo_3B.yaml (1)

34-49: Loss config flag addition is consistent and safely defaulted

force_on_policy_ratio: false is added with a clear comment, and this config already satisfies the documented prerequisite (128 = 8 * 16), so the flag can be safely toggled later without further changes.

examples/configs/grpo_math_1B.yaml (1)

38-56: force_on_policy_ratio flag wired correctly into math GRPO config

The new force_on_policy_ratio: false field and its comment are consistent with other configs, and the batch-size relation (512 = 32 * 16) already meets the stated requirement for enabling it.

examples/penguin/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml (1)

31-45: Penguin GRPO recipe aligns flag and batch-size constraint

Adding force_on_policy_ratio: false with the explanatory comment is consistent with other recipes, and computing train_global_batch_size directly from num_prompts_per_step * num_generations_per_prompt guarantees the documented prerequisite holds when this flag is enabled.

tests/unit/models/policy/test_dtensor_worker.py (1)

667-685: DTensor PG loss test config correctly extended with force_on_policy_ratio

Including "force_on_policy_ratio": False in the ClippedPGLossFn config keeps this test aligned with the updated loss configuration while preserving its original behavior (microbatch-size invariance under standard PPO ratios).

examples/configs/vlm_grpo_3B_megatron.yaml (1)

31-45: Megatron VLM GRPO config cleanly integrates force_on_policy_ratio

The new force_on_policy_ratio: false option and its comment match other configs, and the batch-size relation (128 = 8 * 16) satisfies the documented requirement if users choose to enable this behavior.

tests/unit/algorithms/test_sequence_packing_gradients.py (1)

129-145: Sequence-packing gradient test loss_config updated appropriately

Adding "force_on_policy_ratio": False to loss_config ensures compatibility with the extended ClippedPGLossFn configuration without affecting the baseline vs packed gradient comparisons this test asserts.

tests/unit/algorithms/test_grpo.py (1)

886-902: GRPO unit tests now pass full ClippedPG loss config

The addition of "force_on_policy_ratio": False to the ClippedPGLossFn configuration keeps these GRPO tests aligned with the expanded loss options, while preserving existing training-loop behavior in both synchronous and async modes.

tests/unit/models/policy/test_megatron_worker.py (1)

39-54: Megatron basic_pg_loss_test_config extended with force_on_policy_ratio

Updating basic_pg_loss_test_config to include "force_on_policy_ratio": False correctly aligns all Megatron PG-loss tests with the updated ClippedPGLossConfig while maintaining their existing expectations around ratios and gradients.

nemo_rl/algorithms/grpo.py (2)

1355-1365: Ratio min/max aggregation is reasonable; just ensure upstream always passes iterables

The new handling for probs_ratio_min/max and probs_ratio_clamped_min/max (filtering out inf and falling back to -1.0 when all values are invalid) matches the sentinel semantics from the loss and should keep logs sane even when some microbatches/seqs have no valid tokens. This loop assumes v is iterable (e.g., NumPy array, list); if any training path ever supplies a scalar here, for x in v will fail. Worth double‑checking that train_results["all_mb_metrics"] always uses array‑like containers.
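The filtering described here can be sketched as a small helper (hypothetical name; the real logic lives inline in grpo_train and async_grpo_train):

```python
import math

def reduce_ratio_metric(values, key):
    """Reduce per-microbatch ratio extrema, ignoring +/-inf sentinels.

    Fully masked microbatches emit +/-inf sentinels; if every value is a
    sentinel, fall back to -1.0 so the logged metric stays finite.
    """
    finite = [x for x in values if not math.isinf(x)]
    if not finite:
        return -1.0
    return min(finite) if key.endswith("_min") else max(finite)
```

Note that, as the comment above points out, this assumes `values` is an iterable of floats (list or NumPy array), not a bare scalar.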


2294-2304: Mirror of ratio min/max aggregation in async path

The async GRPO metrics reduction mirrors the sync path and correctly reuses the same inf‑filtering semantics for the new ratio metrics. Same caveat applies: this assumes each metric value is iterable rather than a bare scalar.

nemo_rl/algorithms/loss_functions.py (3)

311-335: On‑policy ratio forcing implementation matches intent

The new force_on_policy_ratio branch (log_ratios = curr_logprobs - curr_logprobs.detach()) correctly yields ratios = 1.0 in the forward pass while keeping non‑zero gradients w.r.t. curr_logprobs, so the actor loss behaves as “truly on‑policy PPO” without referencing prev_logprobs. Clamping is effectively a no‑op in this mode, which is exactly what you want. No issues here.
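The trick can be illustrated without torch by mimicking `.detach()` with a value copy that is held fixed during differentiation. This is a sketch of the mechanism, not the NeMo RL implementation; the gradient is checked by finite differences in place of autograd.

```python
import math

def forced_ratio(curr_logprob: float) -> float:
    # "detached" stands in for curr_logprobs.detach(): same value, no gradient flow
    detached = curr_logprob
    return math.exp(curr_logprob - detached)  # exactly 1.0 in the forward pass

def actor_loss_grad(curr_logprob: float, advantage: float, eps: float = 1e-6) -> float:
    # Finite-difference gradient of the loss -ratio * advantage w.r.t. curr_logprob,
    # holding the detached copy fixed, exactly as autograd would.
    detached = curr_logprob
    f = lambda lp: -math.exp(lp - detached) * advantage
    return (f(curr_logprob + eps) - f(curr_logprob - eps)) / (2 * eps)
```

The forward value is always 1.0, yet the gradient at the current point is -advantage, i.e. the standard on-policy policy-gradient signal, without ever referencing prev_logprobs.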


429-455: Ratio min/max metrics and sentinels look good

The min/max ratio metrics:

  • Compute over valid tokens only (mask.bool()), in line with how the mean ratios are computed.
  • Use ±inf sentinels when there are no valid tokens so downstream aggregation can distinguish “no data” from real extremes.

This design matches the aggregation logic in grpo.py and SequencePackingLossWrapper; just be aware that any consumer that doesn’t implement the same sentinel filtering will see infinities for completely masked batches.
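A pure-Python sketch of the sentinel emission described above (plain lists stand in for masked tensors; the helper name is hypothetical):

```python
def masked_ratio_extrema(ratios, mask):
    """Min/max ratio over valid tokens; +/-inf sentinels when none are valid."""
    valid = [r for r, m in zip(ratios, mask) if m]
    if not valid:
        # Sentinels let downstream aggregation distinguish "no data"
        # from real extreme ratio values.
        return float("inf"), float("-inf")
    return min(valid), max(valid)
```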


944-961: SequencePackingLossWrapper aggregation correctly handles new ratio metrics

Initializing probs_ratio_min/_clamped_min with +inf and the max variants with -inf, then skipping math.isinf(val) when updating, ensures that:

  • Sequences with no valid tokens (which emit ±inf from the base loss) don’t influence the packed aggregate.
  • “True” min/max values across packed sequences are preserved.

For all other metrics you keep the existing sum‑across‑sequences behavior. This is a clean extension for the new metrics.
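The described aggregation across packed sequences can be sketched as follows (hypothetical function names; the real logic is inline in SequencePackingLossWrapper):

```python
import math

def init_extrema_metrics():
    # Min keys start at +inf, max keys at -inf, so any finite value replaces them.
    return {"probs_ratio_min": float("inf"), "probs_ratio_max": float("-inf")}

def update_extrema_metrics(agg, seq_metrics):
    for key, val in seq_metrics.items():
        if key not in agg or math.isinf(val):
            continue  # skip +/-inf sentinels from sequences with no valid tokens
        reduce_fn = min if key.endswith("_min") else max
        agg[key] = reduce_fn(agg[key], val)
    return agg
```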

tests/unit/algorithms/test_loss_functions.py (1)

30-45: Good coverage for the on‑policy ratio path

Extending basic_pg_loss_test_config with force_on_policy_ratio and adding test_clipped_pg_loss_force_on_policy_ratio gives solid unit coverage of the new mode: you reuse the existing PPO scenario, assert the actor loss matches the “all ratios = 1.0” expectation, and verify all ratio‑related metrics report exactly 1. This should catch regressions both in the ratio forcing path and in the new min/max metrics.

Also applies to: 566-619

Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
@yfw yfw added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Dec 2, 2025

@terrykong terrykong left a comment


very clever optimization!

@terrykong terrykong enabled auto-merge (squash) December 4, 2025 07:15
@terrykong terrykong merged commit 1cad374 into main Dec 4, 2025
43 of 44 checks passed
@terrykong terrykong deleted the yifu/force_on_policy branch December 4, 2025 07:15
DeL-TaiseiOzaki pushed a commit to DeL-TaiseiOzaki/RL that referenced this pull request Jan 8, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Jiaqi Zeng <jiaqiz@nvidia.com>
Co-authored-by: Jiaqi Zeng <jiaqiz@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Jiaqi Zeng <jiaqiz@nvidia.com>
Co-authored-by: Jiaqi Zeng <jiaqiz@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Jiaqi Zeng <jiaqiz@nvidia.com>
Co-authored-by: Jiaqi Zeng <jiaqiz@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Jiaqi Zeng <jiaqiz@nvidia.com>
Co-authored-by: Jiaqi Zeng <jiaqiz@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 9, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Jiaqi Zeng <jiaqiz@nvidia.com>
Co-authored-by: Jiaqi Zeng <jiaqiz@nvidia.com>

Labels

CI:L1 Run doctests, unit tests, and functional tests
