1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -53,6 +53,7 @@ loss_fn:
truncated_importance_sampling_ratio: null
sequence_level_importance_ratios: false
token_level_loss: true
force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)

checkpointing:
enabled: true
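The comment on the new force_on_policy_ratio key points at a sizing constraint: the rollout must be consumed in exactly one optimizer step, otherwise later minibatches train against a policy that has already moved and forcing the ratio to 1.0 would silently drop the off-policy correction. A minimal sketch of that check with made-up numbers (the variable names mirror the config keys; nothing below is the repo's API):

# Hypothetical sizing check for force_on_policy_ratio (example values only).
num_prompts_per_step = 64
num_generations_per_prompt = 32
train_global_batch_size = 2048

rollout_samples = num_prompts_per_step * num_generations_per_prompt  # 2048
optimizer_steps_per_rollout = rollout_samples // train_global_batch_size  # 1

# More than one step per rollout would make later minibatches off-policy,
# so ratio = 1.0 is only valid when the two sizes match exactly.
assert rollout_samples == train_global_batch_size
assert optimizer_steps_per_rollout == 1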
1 change: 1 addition & 0 deletions examples/configs/vlm_grpo_3B.yaml
@@ -46,6 +46,7 @@ loss_fn:
use_importance_sampling_correction: false
truncated_importance_sampling_ratio: null
token_level_loss: true
force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)

checkpointing:
enabled: true
1 change: 1 addition & 0 deletions examples/configs/vlm_grpo_3B_megatron.yaml
@@ -42,6 +42,7 @@ loss_fn:
use_importance_sampling_correction: false
truncated_importance_sampling_ratio: null
token_level_loss: true
force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)
checkpointing:
enabled: true
checkpoint_dir: results/clevr_grpo_${policy.model_name}
@@ -41,6 +41,7 @@ loss_fn:
truncated_importance_sampling_ratio: null
use_importance_sampling_correction: false
token_level_loss: true
force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)

checkpointing:
enabled: true
35 changes: 33 additions & 2 deletions nemo_rl/algorithms/grpo.py
@@ -582,6 +582,17 @@ def init_vllm():

loss_fn = ClippedPGLossFn(loss_config)

# Validate force_on_policy_ratio
if loss_config.get("force_on_policy_ratio", False):
assert (
grpo_config["num_prompts_per_step"]
* grpo_config["num_generations_per_prompt"]
== policy_config["train_global_batch_size"]
), (
"force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt"
)
print(" ✓ force_on_policy_ratio enabled")

# Calculate total setup time
total_setup_time = time.perf_counter() - setup_start_time
worker_init_timing_metrics["total_setup_time_s"] = total_setup_time
@@ -1342,7 +1353,17 @@ def grpo_train(

metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {
if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
valid_values = [x for x in v if not np.isinf(x)]
metrics[k] = (
np.min(valid_values).item() if valid_values else -1.0
)
elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
valid_values = [x for x in v if not np.isinf(x)]
metrics[k] = (
np.max(valid_values).item() if valid_values else -1.0
)
elif k in {
"lr",
"wd",
"reward",
@@ -2270,7 +2291,17 @@ def async_grpo_train(
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {
if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
valid_values = [x for x in v if not np.isinf(x)]
metrics[k] = (
np.min(valid_values).item() if valid_values else -1.0
)
elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
valid_values = [x for x in v if not np.isinf(x)]
metrics[k] = (
np.max(valid_values).item() if valid_values else -1.0
)
elif k in {
"lr",
"wd",
"reward",
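Both training loops now reduce the per-microbatch ratio extrema with a min or max instead of the sum used for other logged metrics, skipping the +/-inf sentinels that a microbatch with no valid tokens reports. A small sketch of that reduction under assumed inputs (plain lists standing in for the collected per-microbatch values):

import numpy as np

# Assumed per-microbatch values for one step; inf marks "no valid tokens".
probs_ratio_min_per_mb = [float("inf"), 0.97, 1.02]
probs_ratio_max_per_mb = [float("-inf"), 1.05, 1.11]

valid_mins = [x for x in probs_ratio_min_per_mb if not np.isinf(x)]
valid_maxs = [x for x in probs_ratio_max_per_mb if not np.isinf(x)]

# -1.0 is the logged fallback when every microbatch was empty.
logged_min = np.min(valid_mins).item() if valid_mins else -1.0  # 0.97
logged_max = np.max(valid_maxs).item() if valid_maxs else -1.0  # 1.11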
60 changes: 57 additions & 3 deletions nemo_rl/algorithms/loss_functions.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, NotRequired, Optional, TypedDict, TypeVar

import torch
@@ -45,11 +46,18 @@ class ClippedPGLossConfig(TypedDict):
use_importance_sampling_correction: bool
truncated_importance_sampling_ratio: float | None
token_level_loss: bool
# If True, apply the off-policy importance-sampling correction at the
# sequence level (one weight per generated sample), as in GSPO.
# If False (default), correction is applied at the token level as in the
# original GRPO paper.
sequence_level_importance_ratios: NotRequired[bool]
disable_ppo_ratio: NotRequired[bool]
# If True, force the ratio to 1.0 for truly on-policy behavior,
# eliminating any importance sampling effects.
# NOTE: This should only be used when doing exactly one update per rollout
# (i.e., num_prompts_per_step * num_generations_per_prompt == train_global_batch_size)
force_on_policy_ratio: NotRequired[bool]


class ClippedPGLossDataDict(TypedDict):
@@ -74,6 +82,7 @@ class ClippedPGLossFn(LossFunction):
- GRPO - https://arxiv.org/abs/2402.03300
- REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740
- GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071
- Truly on-policy (set force_on_policy_ratio = True to force ratio = 1.0, requires one update per rollout)

Formula:
L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref)
@@ -114,6 +123,9 @@ def __init__(self, cfg: ClippedPGLossConfig):
self.kl_input_clamp_value = cfg["kl_input_clamp_value"]
self.kl_output_clamp_value = cfg["kl_output_clamp_value"]
self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False)
self.force_on_policy_ratio = cfg.get(
"force_on_policy_ratio", False
) # Force ratio to 1.0
self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"]
self.use_importance_sampling_correction = cfg[
"use_importance_sampling_correction"
@@ -296,7 +308,13 @@ def __call__(
kl = torch.tensor(0.0)

# Calculate clipped loss function if ppo ratio is enabled.
if not self.disable_ppo_ratio:
if self.force_on_policy_ratio:
# Force ratio to 1.0 for truly on-policy behavior
# Use curr_logprobs twice so the ratio is exactly 1.0 while gradients still flow
log_ratios = curr_logprobs - curr_logprobs.detach()
ratios = log_ratios.exp() # numerically exp(0) = 1.0, but the gradient still flows through curr_logprobs
ratios_clamped = ratios
elif not self.disable_ppo_ratio:
log_ratios = curr_logprobs - prev_logprobs
if self.sequence_level_importance_ratios:
seq_log_ratio_mean = masked_mean(
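The new force_on_policy_ratio branch above uses a standard detach trick: subtracting a detached copy of the current log-probs yields a ratio whose value is exactly 1.0 while its gradient is still the gradient of log pi_theta, so the clipped objective collapses to the plain policy-gradient term -A_t * grad log pi_theta. A standalone sketch with toy tensors (not the repo's API) showing both the constant value and the surviving gradient:

import torch

# Toy current log-probs with a gradient path, and toy advantages.
curr_logprobs = torch.tensor([-1.7, -1.0, -0.6], requires_grad=True)
advantages = torch.tensor([1.0, -1.0, 2.0])

# Value is exp(0) = 1 everywhere, but the graph still runs through curr_logprobs.
ratios = (curr_logprobs - curr_logprobs.detach()).exp()
loss = -(advantages * ratios).mean()
loss.backward()

print(ratios)              # tensor([1., 1., 1.], grad_fn=<ExpBackward0>)
print(curr_logprobs.grad)  # tensor([-0.3333,  0.3333, -0.6667]) == -advantages / 3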
@@ -419,6 +437,22 @@ def __call__(
global_normalization_factor=global_valid_toks,
).item()

# Calculate min/max values for ratios (only for valid tokens)
masked_ratios = ratios.detach()[mask.bool()]
masked_ratios_clamped = ratios_clamped.detach()[mask.bool()]

# Handle edge case where there might be no valid tokens
if masked_ratios.numel() > 0:
probs_ratio_min = masked_ratios.min().item()
probs_ratio_max = masked_ratios.max().item()
probs_ratio_clamped_min = masked_ratios_clamped.min().item()
probs_ratio_clamped_max = masked_ratios_clamped.max().item()
else:
probs_ratio_min = float("inf")
probs_ratio_max = float("-inf")
probs_ratio_clamped_min = float("inf")
probs_ratio_clamped_max = float("-inf")

# If you provided a global_valid_{seqs/toks}, all metrics here are globally normalized
# by either sequence or token count, depending on particular metric.
# To get the true metric, you'll need to sum over the microbatch.
@@ -428,6 +462,10 @@ def __call__(
"loss": loss.item(),
"probs_ratio": probs_ratio,
"probs_ratio_clamped": probs_ratio_clamped,
"probs_ratio_min": probs_ratio_min,
"probs_ratio_max": probs_ratio_max,
"probs_ratio_clamped_min": probs_ratio_clamped_min,
"probs_ratio_clamped_max": probs_ratio_clamped_max,
"kl_penalty": kl.item() / self.reference_policy_kl_penalty if kl else 0,
"token_mult_prob_error": mult_prob_error,
"gen_kl_error": gen_kl_error,
@@ -903,8 +941,24 @@ def __call__(
loss_accum += loss
for k, v in metrics.items():
if k not in metrics_accum:
metrics_accum[k] = 0
metrics_accum[k] += v
if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
metrics_accum[k] = float("inf")
elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
metrics_accum[k] = float("-inf")
else:
metrics_accum[k] = 0

val = v.item() if isinstance(v, torch.Tensor) and v.ndim == 0 else v

# Skip inf/-inf sentinel values (from sequences with no valid tokens)
if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
if not math.isinf(val):
metrics_accum[k] = min(metrics_accum[k], val)
elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
if not math.isinf(val):
metrics_accum[k] = max(metrics_accum[k], val)
else:
metrics_accum[k] += val

return loss_accum, metrics_accum
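The accumulator above seeds the ratio extrema with +/-inf and folds each microbatch in with a running min or max, ignoring the sentinel values from chunks that had no valid tokens, while every other metric keeps the original summation. A condensed, self-contained sketch of that update rule (hypothetical metric dicts, not the wrapper's real state):

import math

def accumulate(metrics_accum: dict, mb_metrics: dict) -> None:
    # Fold one microbatch's metrics into the running aggregate (sketch only).
    for k, v in mb_metrics.items():
        if k not in metrics_accum:
            if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
                metrics_accum[k] = float("inf")
            elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
                metrics_accum[k] = float("-inf")
            else:
                metrics_accum[k] = 0.0
        if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
            if not math.isinf(v):  # skip the +inf "no valid tokens" sentinel
                metrics_accum[k] = min(metrics_accum[k], v)
        elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}:
            if not math.isinf(v):  # skip the -inf sentinel
                metrics_accum[k] = max(metrics_accum[k], v)
        else:
            metrics_accum[k] += v

acc: dict = {}
accumulate(acc, {"loss": 0.4, "probs_ratio_min": float("inf")})
accumulate(acc, {"loss": 0.5, "probs_ratio_min": 0.98})
# acc -> {"loss": 0.9, "probs_ratio_min": 0.98}  (loss summed, min taken across chunks)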

1 change: 1 addition & 0 deletions tests/unit/algorithms/test_grpo.py
@@ -897,6 +897,7 @@ def val_iter(self):
"truncated_importance_sampling_ratio": None,
"sequence_level_importance_ratios": False,
"token_level_loss": True,
"force_on_policy_ratio": False,
}
)
logger = MagicMock()
56 changes: 56 additions & 0 deletions tests/unit/algorithms/test_loss_functions.py
@@ -41,6 +41,7 @@
"truncated_importance_sampling_ratio": None, # Disable TIS
"sequence_level_importance_ratios": False,
"token_level_loss": True,
"force_on_policy_ratio": False,
}


@@ -562,6 +563,61 @@ def test_clipped_pg_loss_reinforce_mode():
torch.testing.assert_close(actual_loss, expected_loss)


def test_clipped_pg_loss_force_on_policy_ratio():
"""Tests that force_on_policy_ratio forces ratios to 1.0 while keeping gradients."""
if not torch.cuda.is_available():
pytest.skip("No GPU available")

device = "cuda"
data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)

cfg = deepcopy(basic_pg_loss_test_config)
cfg["force_on_policy_ratio"] = True
loss_fn = ClippedPGLossFn(cfg)

# Use same logprob pattern as PPO clipping test to ensure
# that without the flag, ratios would be [0.5, 1.0, 1.5]
adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device)
prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
curr_lp_masked = torch.tensor(
[[-1.69315, -1.0, -0.59453]], device=device
) # approx log(0.5)-1, log(1)-1, log(1.5)-1

# Fill full tensors (only need first dim for B=1)
data["advantages"][0, 1:] = adv_masked
data["prev_logprobs"][0, 1:] = prev_lp_masked

# Hand-calculated expected loss when ratios are forced to 1.0
ratios = torch.ones_like(adv_masked, device=device)
loss_per_token = -adv_masked * ratios # [-1.0, 1.0, -2.0]
expected_loss = torch.mean(loss_per_token) # (-1 + 1 - 2) / 3 = -0.6666...

input_ids = data["input_ids"]
dummy_logits = _create_exact_logits(
curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device
)

actual_loss, metrics = loss_fn(
dummy_logits,
data,
global_valid_seqs=torch.sum(data["sample_mask"]),
global_valid_toks=torch.sum(
data["sample_mask"].unsqueeze(-1) * data["token_mask"]
),
)

# Loss should match the on-policy expectation
torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-3)

# Ratios and their metrics should all be exactly 1.0
assert metrics["probs_ratio"] == 1.0
assert metrics["probs_ratio_clamped"] == 1.0
assert metrics["probs_ratio_min"] == 1.0
assert metrics["probs_ratio_max"] == 1.0
assert metrics["probs_ratio_clamped_min"] == 1.0
assert metrics["probs_ratio_clamped_max"] == 1.0


@pytest.mark.parametrize("kl_type", ["k1", "k2", "k3"])
def test_calculate_kl(kl_type):
"""Tests KL calculations."""
1 change: 1 addition & 0 deletions tests/unit/algorithms/test_sequence_packing_gradients.py
@@ -139,6 +139,7 @@ def test_sequence_packing_gradients(self):
"truncated_importance_sampling_ratio": None,
"sequence_level_importance_ratios": False,
"token_level_loss": True,
"force_on_policy_ratio": False,
}

base_loss_fn = ClippedPGLossFn(loss_config)
1 change: 1 addition & 0 deletions tests/unit/models/policy/test_dtensor_worker.py
@@ -681,6 +681,7 @@ def test_dtensor_loss_independent_of_microbatch_size_two_gpus(
"truncated_importance_sampling_ratio": None,
"sequence_level_importance_ratios": False,
"token_level_loss": True,
"force_on_policy_ratio": False,
}
)

1 change: 1 addition & 0 deletions tests/unit/models/policy/test_megatron_worker.py
@@ -50,6 +50,7 @@
"truncated_importance_sampling_ratio": None,
"sequence_level_importance_ratios": False,
"token_level_loss": True,
"force_on_policy_ratio": False,
}

