feat: ProRLv2 - add seq-mask-tis truncated importance sampling type #1899
terrykong merged 8 commits into NVIDIA-NeMo:main
Add a new IS filtering mechanism "seq-mask-tis" that masks entire sequences based on the geometric mean of per-token IS ratios, while keeping non-truncated token-level IS weights for gradient correction. Also adds shared `is_filter_drop_frac` metric for both icepop and seq-mask-tis modes, and documents the new option in prorlv2.md. Signed-off-by: jianh <jianh@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com>
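The mechanism described above can be sketched in a few lines. This is a minimal pure-Python illustration, not the PyTorch code in `nemo_rl/algorithms/loss_functions.py`: the function name `seq_mask_tis_weights`, the list-of-lists input layout, and the treatment of padded tokens via `token_mask` are assumptions made for the example. The bounds in the usage note are the corrected reference bounds quoted later in this PR.

```python
import math


def seq_mask_tis_weights(prev_logprobs, generation_logprobs, token_mask,
                         ratio_min, ratio_max):
    """Hypothetical sketch of sequence-level IS masking.

    Each argument is a list with one inner list per sequence.
    Returns (per-token IS weights, fraction of sequences dropped).
    """
    weights = []
    dropped = 0
    for prev, gen, mask in zip(prev_logprobs, generation_logprobs, token_mask):
        # Per-token log IS ratio: log pi_train(token) - log pi_gen(token).
        log_ratios = [p - g for p, g in zip(prev, gen)]
        n_valid = max(sum(mask), 1)
        # Geometric mean of the ratios = exp(masked mean of the log ratios).
        mean_log = sum(lr * m for lr, m in zip(log_ratios, mask)) / n_valid
        geo_mean = math.exp(mean_log)
        keep = ratio_min <= geo_mean <= ratio_max
        if not keep:
            dropped += 1
        # Kept sequences use raw (non-truncated) token-level weights;
        # dropped sequences contribute zero everywhere.
        weights.append(
            [math.exp(lr) * m if keep else 0.0 for lr, m in zip(log_ratios, mask)]
        )
    return weights, dropped / max(len(weights), 1)
```

For example, with `ratio_min=0.999` and `ratio_max=1.002`, a sequence whose training and generation log-probs agree exactly is kept with weights of 1.0, while a sequence whose log-probs differ by 1.0 on every token has a geometric-mean ratio of about 2.72 and is masked out entirely.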
📝 Walkthrough
The PR introduces seq-mask-tis, a sequence-level importance sampling approach as an alternative to ICE-POP. Documentation is updated with terminology clarification and a detailed comparison between the methods. The implementation adds seq-mask-tis support to the loss function with validation rules and metric tracking.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/loss_functions.py`:
- Around line 171-183: In the __init__ (or initializer) validation block where
truncated_importance_sampling_type is checked (the assert on
self.truncated_importance_sampling_type and the subsequent check for
"seq-mask-tis"), add a guard that when self.truncated_importance_sampling_type
is "icepop" or "seq-mask-tis", then self.truncated_importance_sampling_ratio_min
is not None; raise an assertion with a clear message referencing
truncated_importance_sampling_ratio_min and the allowed sampling types to
prevent later TypeError when comparing tensors to None in methods that use this
attribute.
- Line 579: The metric is_filter_drop_frac is a fraction and must not be summed
across packed sequences; instead implement the same special aggregation used for
min/max metrics in SequencePackingLossWrapper: when encountering key
"is_filter_drop_frac" in metrics_accum, accumulate a weighted sum (e.g.,
metrics_accum_sum["is_filter_drop_frac"] += val * weight) and a corresponding
total weight (metrics_accum_weight["is_filter_drop_frac"] += weight) where
weight is the number of examples/sequence length for that packed segment, then
compute the final fraction as metrics_accum_sum / metrics_accum_weight before
reporting; update the code paths that currently do metrics_accum[k] += val to
detect "is_filter_drop_frac" and use this weighted accumulation and final
division (follow the same pattern used for the existing min/max handling in
SequencePackingLossWrapper and use the same helper variables/keys to keep
aggregation consistent).
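The two suggestions above can be sketched as follows. This is a hedged, standalone illustration rather than the change that landed: `check_ratio_min` and `aggregate_fraction` are hypothetical helper names, and the real `SequencePackingLossWrapper` may accumulate its metrics differently.

```python
from typing import Iterable, Optional, Tuple

# Modes that compare IS ratios against a lower bound (per the review comment).
FILTERING_TYPES = ("icepop", "seq-mask-tis")


def check_ratio_min(sampling_type: Optional[str],
                    ratio_min: Optional[float]) -> None:
    """Fail early instead of hitting a TypeError when comparing to None."""
    if sampling_type in FILTERING_TYPES:
        assert ratio_min is not None, (
            "truncated_importance_sampling_ratio_min must be set when "
            f"truncated_importance_sampling_type is one of {FILTERING_TYPES}"
        )


def aggregate_fraction(segments: Iterable[Tuple[float, float]]) -> float:
    """Combine per-segment fractions as a weighted mean, not a plain sum.

    Each item is (fraction, weight), where weight is the number of items
    (tokens or sequences) the fraction was computed over.
    """
    weighted_sum = 0.0
    total_weight = 0.0
    for frac, weight in segments:
        weighted_sum += frac * weight
        total_weight += weight
    return weighted_sum / total_weight if total_weight > 0 else 0.0
```

Summing the fractions 0.5 (over 2 items) and 0.0 (over 6 items) would report 0.5, whereas the weighted mean gives 0.125, which is the true fraction of dropped items across both segments.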
🧹 Nitpick comments (3)
docs/guides/prorlv2.md (2)
158-184: Documentation for seq-mask-tis looks good overall. The new section clearly explains the mechanism, provides a comparison table with ICE-POP, and includes configuration snippets.
One observation: the reference bounds in the table (`min=0.002, max=0.003`) represent a very narrow band for the geometric-mean IS ratio. Consider adding a brief note explaining why these bounds are so far from 1.0 (and so tight), or pointing users to the referenced blog for tuning guidance. Users unfamiliar with the method may assume bounds closer to 1.0 are expected.
182-182: Clarify metric semantics difference. Line 182 notes that `is_filter_drop_frac` represents "fraction of tokens (ICE-POP) or sequences (seq-mask-tis)" filtered out. Since the same metric name measures different granularities depending on the mode, this could cause confusion when comparing runs. Consider noting this caveat more prominently or suggesting users check which mode is active when interpreting this metric.
nemo_rl/algorithms/loss_functions.py (1)
418-424: Replace EN DASH (–) with HYPHEN-MINUS (-) in comments. Ruff (RUF003) flags ambiguous Unicode EN DASH characters in these comment lines. While visually similar, they can cause issues with some tools and are not idiomatic in source code.
Proposed fix
- # "tis" – clamp IS weights to [0, max]
- # "icepop" – zero out tokens whose IS weight ∉ [min, max] (ref bounds: 0.5–5)
- # "seq-mask-tis" – zero out entire sequences whose geometric-mean
- # IS ratio ∉ [min, max]; retained sequences keep
- # raw (non-truncated) token-level IS weights (ref bounds: 0.002–0.003)
+ # "tis" - clamp IS weights to [0, max]
+ # "icepop" - zero out tokens whose IS weight not in [min, max] (ref bounds: 0.5-5)
+ # "seq-mask-tis" - zero out entire sequences whose geometric-mean
+ # IS ratio not in [min, max]; retained sequences keep
+ # raw (non-truncated) token-level IS weights (ref bounds: 0.002-0.003)
The same applies to the comment block at lines 48-51:
- # "tis" – clamp IS weights to max
- # "icepop" – zero out tokens with IS weight outside [min, max]
- # "seq-mask-tis" – zero out sequences by geometric-mean IS ratio, non-truncated token IS correction
+ # "tis" - clamp IS weights to max
+ # "icepop" - zero out tokens with IS weight outside [min, max]
+ # "seq-mask-tis" - zero out sequences by geometric-mean IS ratio, non-truncated token IS correction
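For contrast with the sequence-level mode, the "icepop" behavior named in these comments (zeroing individual tokens whose IS weight falls outside [min, max]) might look roughly like this. It is a pure-Python sketch under assumptions: `icepop_weights` is a hypothetical name, the inputs are flat per-token lists, and the drop fraction is counted over unmasked tokens only.

```python
def icepop_weights(is_weights, token_mask, ratio_min, ratio_max):
    """Hypothetical token-level filter: keep in-bounds weights, zero the rest.

    Returns (filtered weights, fraction of valid tokens dropped).
    """
    filtered = []
    n_dropped = 0
    for w, m in zip(is_weights, token_mask):
        in_bounds = ratio_min <= w <= ratio_max
        if m and not in_bounds:
            n_dropped += 1
        # Padded tokens (mask == 0) never contribute, in or out of bounds.
        filtered.append(w if (m and in_bounds) else 0.0)
    n_valid = max(sum(token_mask), 1)
    return filtered, n_dropped / n_valid
```

With the reference bounds 0.5 and 5 from the comment, weights of 0.1 and 6.0 are zeroed while 1.0 and 2.0 pass through unchanged, so the filter acts per token rather than per sequence.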
@terrykong @yfw the l1/l0 tests passed.
terrykong left a comment:
Thanks for contributing, @hijkzzz. To help others who come across this PR, are there any experimental results you can share showing how this helps stability in your experiments?
In the Demystifying blog they show "Token-Level MIS < Sequence-Level MIS" for stability. Any reason why you implemented the token-level one instead of the sequence-level one first?
Also, I think it would be good to have unit tests for all these importance sampling techniques, for correctness.
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com> Signed-off-by: hijkzzz <janhu9527@gmail.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com> Signed-off-by: hijkzzz <janhu9527@gmail.com>
- Apply nan_to_num to `prev_logprobs - generation_logprobs` before masked_mean in seq-mask-tis, preventing inf/NaN from corrupting the geometric-mean IS ratio computation.
- Rename icepop metric key to is_oob_ratio for consistency with seq-mask-tis.
- Fix seq-mask-tis reference bounds in docs (0.999–1.002, not 0.002–0.003) and correct swapped yaml config values.
- Add unit tests for icepop and seq-mask-tis code paths in ClippedPGLossFn, including nan_to_num coverage.
Signed-off-by: jianh <jianh@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
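To see why sanitizing the log-ratio difference before the masked mean matters, here is a small pure-Python sketch. It mimics the idea of `nan_to_num` with plain floats. The replacement values (0.0 for NaN and for both infinities) are an illustrative assumption, not necessarily what the PR passes, and both helper names are hypothetical.

```python
import math


def sanitize_log_ratios(prev_logprobs, generation_logprobs,
                        nan=0.0, posinf=0.0, neginf=0.0):
    """Replace NaN/inf differences with finite values (assumed defaults)."""
    out = []
    for p, g in zip(prev_logprobs, generation_logprobs):
        d = p - g  # e.g. (-inf) - (-inf) is NaN
        if math.isnan(d):
            d = nan
        elif d == math.inf:
            d = posinf
        elif d == -math.inf:
            d = neginf
        out.append(d)
    return out


def masked_mean(values, mask):
    """Mean over positions where mask is 1; note NaN * 0 is still NaN."""
    total = sum(v * m for v, m in zip(values, mask))
    return total / max(sum(mask), 1)
```

Because `NaN * 0` is still NaN, masking alone cannot remove a bad position: a single NaN or infinite difference would propagate through the mean and then through `exp` into the geometric-mean ratio, which is why the difference is cleaned first.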
We found seq-based filtering to be more stable for MoE models.
@terrykong all tests passed, please merge it.
@terrykong fixed, please merge it.
Calculate the out-of-bounds ratio for the "tis" type so users can monitor how often IS weights exceed the truncation threshold, consistent with the existing metrics for "icepop" and "seq-mask-tis". Signed-off-by: jianh <jianh@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com>
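A rough sketch of what such a monitoring metric could compute for the "tis" type: clamp the weights at the maximum and report how many valid tokens exceeded it. The function name and the flat-list inputs are assumptions for illustration, not the repository's actual code.

```python
def tis_weights_and_oob_ratio(is_weights, token_mask, ratio_max):
    """Hypothetical helper: clamp IS weights to ratio_max and report the
    fraction of unmasked tokens whose raw weight was above the threshold."""
    clamped = [min(w, ratio_max) for w in is_weights]
    n_valid = max(sum(token_mask), 1)
    n_oob = sum(1 for w, m in zip(is_weights, token_mask) if m and w > ratio_max)
    return clamped, n_oob / n_valid
```

Unlike the filtering modes, out-of-bounds tokens here still contribute to the loss with a truncated weight, so the ratio measures how often truncation is active rather than how much data is discarded.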
…VIDIA-NeMo#1899) Signed-off-by: jianh <jianh@nvidia.com> Signed-off-by: hijkzzz <janhu9527@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Terry Kong <terrycurtiskong@gmail.com> Signed-off-by: yuanhangs <yuanhangs@nvidia.com>