
Commit 13a5a9a

yfw and HeyyyyyyG authored
feat: force on-policy ratio to 1 (#1529)
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Jiaqi Zeng <jiaqiz@nvidia.com>
Co-authored-by: Jiaqi Zeng <jiaqiz@nvidia.com>
1 parent 5070dd1 commit 13a5a9a

File tree

11 files changed (+153 lines, -5 lines)

examples/configs/grpo_math_1B.yaml

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ loss_fn:
   truncated_importance_sampling_ratio: null
   sequence_level_importance_ratios: false
   token_level_loss: true
+  force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)
 
 checkpointing:
   enabled: true
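The constraint in the new comment means every generated sample is consumed in exactly one global batch, i.e. one optimizer update per rollout. A minimal sketch of a consistent setup, with hypothetical values (the grpo/policy/loss_fn section layout follows this repo's configs; the numbers are illustrative only):

grpo:
  num_prompts_per_step: 32
  num_generations_per_prompt: 16
policy:
  train_global_batch_size: 512  # 32 * 16, satisfies the force_on_policy_ratio requirement
loss_fn:
  force_on_policy_ratio: true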

examples/configs/vlm_grpo_3B.yaml

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ loss_fn:
   use_importance_sampling_correction: false
   truncated_importance_sampling_ratio: null
   token_level_loss: true
+  force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)
 
 checkpointing:
   enabled: true

examples/configs/vlm_grpo_3B_megatron.yaml

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ loss_fn:
   use_importance_sampling_correction: false
   truncated_importance_sampling_ratio: null
   token_level_loss: true
+  force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)
 checkpointing:
   enabled: true
   checkpoint_dir: results/clevr_grpo_${policy.model_name}

examples/penguin/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ loss_fn:
   truncated_importance_sampling_ratio: null
   use_importance_sampling_correction: false
   token_level_loss: true
+  force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)
 
 checkpointing:
   enabled: true

nemo_rl/algorithms/grpo.py

Lines changed: 33 additions & 2 deletions
@@ -601,6 +601,17 @@ def init_vllm():
 
     loss_fn = ClippedPGLossFn(loss_config)
 
+    # Validate force_on_policy_ratio
+    if loss_config.get("force_on_policy_ratio", False):
+        assert (
+            grpo_config["num_prompts_per_step"]
+            * grpo_config["num_generations_per_prompt"]
+            == policy_config["train_global_batch_size"]
+        ), (
+            "force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt"
+        )
+        print(" ✓ force_on_policy_ratio enabled")
+
     # Calculate total setup time
     total_setup_time = time.perf_counter() - setup_start_time
     worker_init_timing_metrics["total_setup_time_s"] = total_setup_time

@@ -1425,7 +1436,17 @@ def grpo_train(
 
             metrics.update(train_results["all_mb_metrics"])
             for k, v in metrics.items():
-                if k in {
+                if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
+                    valid_values = [x for x in v if not np.isinf(x)]
+                    metrics[k] = (
+                        np.min(valid_values).item() if valid_values else -1.0
+                    )
+                elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
+                    valid_values = [x for x in v if not np.isinf(x)]
+                    metrics[k] = (
+                        np.max(valid_values).item() if valid_values else -1.0
+                    )
+                elif k in {
                     "lr",
                     "wd",
                     "reward",

@@ -2369,7 +2390,17 @@ def async_grpo_train(
             )
             metrics.update(train_results["all_mb_metrics"])
             for k, v in metrics.items():
-                if k in {
+                if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
+                    valid_values = [x for x in v if not np.isinf(x)]
+                    metrics[k] = (
+                        np.min(valid_values).item() if valid_values else -1.0
+                    )
+                elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
+                    valid_values = [x for x in v if not np.isinf(x)]
+                    metrics[k] = (
+                        np.max(valid_values).item() if valid_values else -1.0
+                    )
+                elif k in {
                     "lr",
                     "wd",
                     "reward",

nemo_rl/algorithms/loss_functions.py

Lines changed: 56 additions & 3 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from typing import Any, NotRequired, Optional, TypedDict, TypeVar
 
 import torch

@@ -50,6 +51,12 @@ class ClippedPGLossConfig(TypedDict):
     # If False (default), correction is applied at the token level as in the
     # original GRPO paper.
     sequence_level_importance_ratios: NotRequired[bool]
+    disable_ppo_ratio: NotRequired[bool]
+    # If True, force the ratio to 1.0 for truly on-policy behavior,
+    # eliminating any importance sampling effects.
+    # NOTE: This should only be used when doing exactly one update per rollout
+    # (i.e., num_prompts_per_step * num_generations_per_prompt == train_global_batch_size)
+    force_on_policy_ratio: NotRequired[bool]
 
 
 class ClippedPGLossDataDict(TypedDict):

@@ -74,6 +81,7 @@ class ClippedPGLossFn(LossFunction):
     - GRPO - https://arxiv.org/abs/2402.03300
     - REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740
     - GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071
+    - Truly on-policy (set force_on_policy_ratio = True to force ratio = 1.0, requires one update per rollout)
 
     Formula:
     L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref)

@@ -114,6 +122,9 @@ def __init__(self, cfg: ClippedPGLossConfig):
         self.kl_input_clamp_value = cfg["kl_input_clamp_value"]
         self.kl_output_clamp_value = cfg["kl_output_clamp_value"]
         self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False)
+        self.force_on_policy_ratio = cfg.get(
+            "force_on_policy_ratio", False
+        )  # Force ratio to 1.0
         self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"]
         self.use_importance_sampling_correction = cfg[
             "use_importance_sampling_correction"

@@ -296,7 +307,13 @@ def __call__(
             kl = torch.tensor(0.0)
 
         # Calculate clipped loss function if ppo ratio is enabled.
-        if not self.disable_ppo_ratio:
+        if self.force_on_policy_ratio:
+            # Force ratio to 1.0 for truly on-policy behavior
+            # Use curr_logprobs twice so ratio=1 but gradients still flow
+            log_ratios = curr_logprobs - curr_logprobs.detach()
+            ratios = log_ratios.exp()  # = exp(0) = 1.0, but depends on curr_logprobs
+            ratios_clamped = ratios
+        elif not self.disable_ppo_ratio:
             log_ratios = curr_logprobs - prev_logprobs
             if self.sequence_level_importance_ratios:
                 seq_log_ratio_mean = masked_mean(
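Subtracting a detached copy is the standard identity for a constant forward value with a live backward path: curr_logprobs - curr_logprobs.detach() evaluates to exactly zero, so ratios is exactly one, while the backward pass still differentiates through curr_logprobs, collapsing the clipped surrogate to the vanilla policy-gradient term -A * grad(log pi). A minimal standalone check of that claim (hypothetical tensors, not code from this module):

import torch

curr_logprobs = torch.tensor([0.5, -1.2], requires_grad=True)
advantages = torch.tensor([2.0, -1.0])

# forward value is exp(0) = 1 elementwise, but the graph still reaches curr_logprobs
ratios = (curr_logprobs - curr_logprobs.detach()).exp()
loss = -(advantages * ratios).sum()
loss.backward()

print(ratios)              # tensor([1., 1.], grad_fn=<ExpBackward0>)
print(curr_logprobs.grad)  # tensor([-2., 1.]), i.e. -advantages: the REINFORCE gradient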
@@ -419,6 +436,22 @@ def __call__(
                 global_normalization_factor=global_valid_toks,
             ).item()
 
+        # Calculate min/max values for ratios (only for valid tokens)
+        masked_ratios = ratios.detach()[mask.bool()]
+        masked_ratios_clamped = ratios_clamped.detach()[mask.bool()]
+
+        # Handle edge case where there might be no valid tokens
+        if masked_ratios.numel() > 0:
+            probs_ratio_min = masked_ratios.min().item()
+            probs_ratio_max = masked_ratios.max().item()
+            probs_ratio_clamped_min = masked_ratios_clamped.min().item()
+            probs_ratio_clamped_max = masked_ratios_clamped.max().item()
+        else:
+            probs_ratio_min = float("inf")
+            probs_ratio_max = float("-inf")
+            probs_ratio_clamped_min = float("inf")
+            probs_ratio_clamped_max = float("-inf")
+
         # If you provided a global_valid_{seqs/toks}, all metrics here are globally normalized
         # by either sequence or token count, depending on particular metric.
         # To get the true metric, you'll need to sum over the microbatch.

@@ -428,6 +461,10 @@ def __call__(
             "loss": loss.item(),
             "probs_ratio": probs_ratio,
             "probs_ratio_clamped": probs_ratio_clamped,
+            "probs_ratio_min": probs_ratio_min,
+            "probs_ratio_max": probs_ratio_max,
+            "probs_ratio_clamped_min": probs_ratio_clamped_min,
+            "probs_ratio_clamped_max": probs_ratio_clamped_max,
             "kl_penalty": kl.item() / self.reference_policy_kl_penalty if kl else 0,
             "token_mult_prob_error": mult_prob_error,
             "gen_kl_error": gen_kl_error,

@@ -903,8 +940,24 @@ def __call__(
             loss_accum += loss
             for k, v in metrics.items():
                 if k not in metrics_accum:
-                    metrics_accum[k] = 0
-                metrics_accum[k] += v
+                    if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
+                        metrics_accum[k] = float("inf")
+                    elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
+                        metrics_accum[k] = float("-inf")
+                    else:
+                        metrics_accum[k] = 0
+
+                val = v.item() if isinstance(v, torch.Tensor) and v.ndim == 0 else v
+
+                # Skip inf/-inf sentinel values (from sequences with no valid tokens)
+                if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
+                    if not math.isinf(val):
+                        metrics_accum[k] = min(metrics_accum[k], val)
+                elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
+                    if not math.isinf(val):
+                        metrics_accum[k] = max(metrics_accum[k], val)
+                else:
+                    metrics_accum[k] += val
 
         return loss_accum, metrics_accum
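Because min and max cannot be accumulated by summation, the accumulator seeds these keys with ±inf and folds each microbatch in with min/max, skipping the sentinel values emitted by microbatches with no valid tokens; every other metric keeps the original additive behavior. A compact toy run of that logic (illustrative dictionaries, not the wrapper itself):

import math

metrics_accum = {}
microbatches = [
    {"probs_ratio_min": float("inf"), "loss": 0.5},  # microbatch with no valid tokens
    {"probs_ratio_min": 0.97, "loss": 0.4},
]
for mb in microbatches:
    for k, v in mb.items():
        if k == "probs_ratio_min":
            metrics_accum.setdefault(k, float("inf"))
            if not math.isinf(v):  # skip the sentinel
                metrics_accum[k] = min(metrics_accum[k], v)
        else:
            metrics_accum[k] = metrics_accum.get(k, 0) + v

print(metrics_accum)  # {'probs_ratio_min': 0.97, 'loss': 0.9}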

tests/unit/algorithms/test_grpo.py

Lines changed: 1 addition & 0 deletions
@@ -897,6 +897,7 @@ def val_iter(self):
             "truncated_importance_sampling_ratio": None,
             "sequence_level_importance_ratios": False,
             "token_level_loss": True,
+            "force_on_policy_ratio": False,
         }
     )
     logger = MagicMock()

tests/unit/algorithms/test_loss_functions.py

Lines changed: 56 additions & 0 deletions
@@ -41,6 +41,7 @@
     "truncated_importance_sampling_ratio": None,  # Disable TIS
     "sequence_level_importance_ratios": False,
     "token_level_loss": True,
+    "force_on_policy_ratio": False,
 }
 
 
@@ -562,6 +563,61 @@ def test_clipped_pg_loss_reinforce_mode():
     torch.testing.assert_close(actual_loss, expected_loss)
 
 
+def test_clipped_pg_loss_force_on_policy_ratio():
+    """Tests that force_on_policy_ratio forces ratios to 1.0 while keeping gradients."""
+    if not torch.cuda.is_available():
+        pytest.skip("No GPU available")
+
+    device = "cuda"
+    data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)
+
+    cfg = deepcopy(basic_pg_loss_test_config)
+    cfg["force_on_policy_ratio"] = True
+    loss_fn = ClippedPGLossFn(cfg)
+
+    # Use same logprob pattern as PPO clipping test to ensure
+    # that without the flag, ratios would be [0.5, 1.0, 1.5]
+    adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device)
+    prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+    curr_lp_masked = torch.tensor(
+        [[-1.69315, -1.0, -0.59453]], device=device
+    )  # approx log(0.5)-1, log(1)-1, log(1.5)-1
+
+    # Fill full tensors (only need first dim for B=1)
+    data["advantages"][0, 1:] = adv_masked
+    data["prev_logprobs"][0, 1:] = prev_lp_masked
+
+    # Hand-calculated expected loss when ratios are forced to 1.0
+    ratios = torch.ones_like(adv_masked, device=device)
+    loss_per_token = -adv_masked * ratios  # [-1.0, 1.0, -2.0]
+    expected_loss = torch.mean(loss_per_token)  # (-1 + 1 - 2) / 3 = -0.6666...
+
+    input_ids = data["input_ids"]
+    dummy_logits = _create_exact_logits(
+        curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device
+    )
+
+    actual_loss, metrics = loss_fn(
+        dummy_logits,
+        data,
+        global_valid_seqs=torch.sum(data["sample_mask"]),
+        global_valid_toks=torch.sum(
+            data["sample_mask"].unsqueeze(-1) * data["token_mask"]
+        ),
+    )
+
+    # Loss should match the on-policy expectation
+    torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-3)
+
+    # Ratios and their metrics should all be exactly 1.0
+    assert metrics["probs_ratio"] == 1.0
+    assert metrics["probs_ratio_clamped"] == 1.0
+    assert metrics["probs_ratio_min"] == 1.0
+    assert metrics["probs_ratio_max"] == 1.0
+    assert metrics["probs_ratio_clamped_min"] == 1.0
+    assert metrics["probs_ratio_clamped_max"] == 1.0
+
+
 @pytest.mark.parametrize("kl_type", ["k1", "k2", "k3"])
 def test_calculate_kl(kl_type):
     """Tests KL calculations."""

tests/unit/algorithms/test_sequence_packing_gradients.py

Lines changed: 1 addition & 0 deletions
@@ -139,6 +139,7 @@ def test_sequence_packing_gradients(self):
             "truncated_importance_sampling_ratio": None,
             "sequence_level_importance_ratios": False,
             "token_level_loss": True,
+            "force_on_policy_ratio": False,
         }
 
         base_loss_fn = ClippedPGLossFn(loss_config)

tests/unit/models/policy/test_dtensor_worker.py

Lines changed: 1 addition & 0 deletions
@@ -681,6 +681,7 @@ def test_dtensor_loss_independent_of_microbatch_size_two_gpus(
             "truncated_importance_sampling_ratio": None,
             "sequence_level_importance_ratios": False,
             "token_level_loss": True,
+            "force_on_policy_ratio": False,
         }
     )
