Skip to content

Commit f2de476

Browse files
authored
feat: support truncated importance sampling (#1348)
Signed-off-by: Yuki Huang <[email protected]>
1 parent: f286857 · commit: f2de476

File tree

9 files changed

+239
-157
lines changed

9 files changed

+239
-157
lines changed

examples/configs/grpo_math_1B.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ loss_fn:
4343
# Async GRPO requires importance sampling correction enabled
4444
# Set to true when async_grpo.enabled is true
4545
use_importance_sampling_correction: false
46+
truncated_importance_sampling_ratio: null
4647
sequence_level_importance_ratios: false
4748
token_level_loss: true
4849

examples/configs/vlm_grpo_3B.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ loss_fn:
3939
# (default off) loss formulation improvements (docs/guides/grpo.md#loss)
4040
use_on_policy_kl_approximation: false
4141
use_importance_sampling_correction: false
42+
truncated_importance_sampling_ratio: null
4243
token_level_loss: true
4344

4445
checkpointing:

examples/configs/vlm_grpo_3B_megatron.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ loss_fn:
3535
ratio_clip_c: null
3636
use_on_policy_kl_approximation: false
3737
use_importance_sampling_correction: false
38+
truncated_importance_sampling_ratio: null
3839
token_level_loss: true
3940
checkpointing:
4041
enabled: true

nemo_rl/algorithms/loss_functions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class ClippedPGLossConfig(TypedDict):
4242
ratio_clip_c: float
4343
use_on_policy_kl_approximation: bool
4444
use_importance_sampling_correction: bool
45+
truncated_importance_sampling_ratio: float | None
4546
token_level_loss: bool
4647
# If True, apply the off-policy importance-sampling correction at the
4748
# sequence level (one weight per generated sample), as in GSPO.
@@ -113,6 +114,9 @@ def __init__(self, cfg: ClippedPGLossConfig):
113114
self.use_importance_sampling_correction = cfg[
114115
"use_importance_sampling_correction"
115116
]
117+
self.truncated_importance_sampling_ratio = cfg[
118+
"truncated_importance_sampling_ratio"
119+
]
116120
# Whether to compute importance weights per-sequence instead of per-token.
117121
self.sequence_level_importance_ratios = cfg.get(
118122
"sequence_level_importance_ratios",
@@ -125,6 +129,13 @@ def __init__(self, cfg: ClippedPGLossConfig):
125129
assert self.loss_type == LossType.SEQUENCE_LEVEL, (
126130
"sequence-level importance sampling (e.g. GSPO) is mutually exclusive with token-level loss"
127131
)
132+
if self.truncated_importance_sampling_ratio is not None:
133+
assert self.use_importance_sampling_correction, (
134+
"truncated_importance_sampling_ratio is only supported when use_importance_sampling_correction is True"
135+
)
136+
assert self.truncated_importance_sampling_ratio > 0, (
137+
"truncated_importance_sampling_ratio should be positive"
138+
)
128139

129140
def __call__(
130141
self,
@@ -280,6 +291,12 @@ def __call__(
280291
actor_importance_weights_expanded = torch.nan_to_num(
281292
actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0
282293
)
294+
# TIS see https://fengyao.notion.site/off-policy-rl
295+
if self.truncated_importance_sampling_ratio is not None:
296+
actor_importance_weights_expanded = torch.clamp(
297+
actor_importance_weights_expanded,
298+
max=self.truncated_importance_sampling_ratio,
299+
)
283300
actor_importance_weights = actor_importance_weights_expanded
284301
del actor_importance_weights_expanded
285302
if self.use_importance_sampling_correction:

tests/unit/algorithms/test_grpo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,8 @@ def val_iter(self):
889889
"ratio_clip_c": 1.0,
890890
"use_on_policy_kl_approximation": False,
891891
"use_importance_sampling_correction": False,
892+
"truncated_importance_sampling_ratio": None,
893+
"sequence_level_importance_ratios": False,
892894
"token_level_loss": True,
893895
}
894896
)

0 commit comments

Comments (0)