1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import math
1415from typing import Any , NotRequired , Optional , TypedDict , TypeVar
1516
1617import torch
@@ -50,6 +51,12 @@ class ClippedPGLossConfig(TypedDict):
5051 # If False (default), correction is applied at the token level as in the
5152 # original GRPO paper.
5253 sequence_level_importance_ratios : NotRequired [bool ]
54+ disable_ppo_ratio : NotRequired [bool ]
55+ # If True, force the ratio to 1.0 for truly on-policy behavior,
56+ # eliminating any importance sampling effects.
57+ # NOTE: This should only be used when doing exactly one update per rollout
58+ # (i.e., num_prompts_per_step * num_generations_per_prompt == train_global_batch_size)
59+ force_on_policy_ratio : NotRequired [bool ]
5360
5461
5562class ClippedPGLossDataDict (TypedDict ):
@@ -74,6 +81,7 @@ class ClippedPGLossFn(LossFunction):
7481 - GRPO - https://arxiv.org/abs/2402.03300
7582 - REINFORCE/RLOO (set disable_ppo_ratio = True; ratio_clip_min/ratio_clip_max are then ignored) - https://arxiv.org/abs/2402.14740
7683 - GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071
84+ - Truly on-policy (set force_on_policy_ratio = True to force ratio = 1.0, requires one update per rollout)
7785
7886 Formula:
7987 L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref)
@@ -114,6 +122,9 @@ def __init__(self, cfg: ClippedPGLossConfig):
114122 self .kl_input_clamp_value = cfg ["kl_input_clamp_value" ]
115123 self .kl_output_clamp_value = cfg ["kl_output_clamp_value" ]
116124 self .disable_ppo_ratio = cfg .get ("disable_ppo_ratio" , False )
125+ self .force_on_policy_ratio = cfg .get (
126+ "force_on_policy_ratio" , False
127+ ) # Force ratio to 1.0
117128 self .use_on_policy_kl_approximation = cfg ["use_on_policy_kl_approximation" ]
118129 self .use_importance_sampling_correction = cfg [
119130 "use_importance_sampling_correction"
@@ -296,7 +307,13 @@ def __call__(
296307 kl = torch .tensor (0.0 )
297308
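298309 # Compute the ratios: forced to 1.0 when force_on_policy_ratio is set; otherwise, unless disable_ppo_ratio is set, the PPO ratio from curr_logprobs - prev_logprobs.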
299- if not self .disable_ppo_ratio :
310+ if self .force_on_policy_ratio :
311+ # Force ratio to 1.0 for truly on-policy behavior
312+ # Subtract a detached copy of curr_logprobs so the ratio is exactly 1 but gradients still flow
313+ log_ratios = curr_logprobs - curr_logprobs .detach ()
314+ ratios = log_ratios .exp () # = exp(0) = 1.0, but depends on curr_logprobs
315+ ratios_clamped = ratios
316+ elif not self .disable_ppo_ratio :
300317 log_ratios = curr_logprobs - prev_logprobs
301318 if self .sequence_level_importance_ratios :
302319 seq_log_ratio_mean = masked_mean (
@@ -419,6 +436,22 @@ def __call__(
419436 global_normalization_factor = global_valid_toks ,
420437 ).item ()
421438
439+ # Calculate min/max values for ratios (only for valid tokens)
440+ masked_ratios = ratios .detach ()[mask .bool ()]
441+ masked_ratios_clamped = ratios_clamped .detach ()[mask .bool ()]
442+
443+ # Handle edge case where there might be no valid tokens
444+ if masked_ratios .numel () > 0 :
445+ probs_ratio_min = masked_ratios .min ().item ()
446+ probs_ratio_max = masked_ratios .max ().item ()
447+ probs_ratio_clamped_min = masked_ratios_clamped .min ().item ()
448+ probs_ratio_clamped_max = masked_ratios_clamped .max ().item ()
449+ else :
450+ probs_ratio_min = float ("inf" )
451+ probs_ratio_max = float ("-inf" )
452+ probs_ratio_clamped_min = float ("inf" )
453+ probs_ratio_clamped_max = float ("-inf" )
454+
422455 # If you provided a global_valid_{seqs/toks}, all metrics here are globally normalized
423456 # by either sequence or token count, depending on particular metric.
424457 # To get the true metric, you'll need to sum over the microbatch.
@@ -428,6 +461,10 @@ def __call__(
428461 "loss" : loss .item (),
429462 "probs_ratio" : probs_ratio ,
430463 "probs_ratio_clamped" : probs_ratio_clamped ,
464+ "probs_ratio_min" : probs_ratio_min ,
465+ "probs_ratio_max" : probs_ratio_max ,
466+ "probs_ratio_clamped_min" : probs_ratio_clamped_min ,
467+ "probs_ratio_clamped_max" : probs_ratio_clamped_max ,
431468 "kl_penalty" : kl .item () / self .reference_policy_kl_penalty if kl else 0 ,
432469 "token_mult_prob_error" : mult_prob_error ,
433470 "gen_kl_error" : gen_kl_error ,
@@ -903,8 +940,24 @@ def __call__(
903940 loss_accum += loss
904941 for k , v in metrics .items ():
905942 if k not in metrics_accum :
906- metrics_accum [k ] = 0
907- metrics_accum [k ] += v
943+ if k in {"probs_ratio_min" , "probs_ratio_clamped_min" }:
944+ metrics_accum [k ] = float ("inf" )
945+ elif k in {"probs_ratio_max" , "probs_ratio_clamped_max" }:
946+ metrics_accum [k ] = float ("-inf" )
947+ else :
948+ metrics_accum [k ] = 0
949+
950+ val = v .item () if isinstance (v , torch .Tensor ) and v .ndim == 0 else v
951+
952+ # Skip inf/-inf sentinel values (from sequences with no valid tokens)
953+ if k in {"probs_ratio_min" , "probs_ratio_clamped_min" }:
954+ if not math .isinf (val ):
955+ metrics_accum [k ] = min (metrics_accum [k ], val )
956+ elif k in {"probs_ratio_max" , "probs_ratio_clamped_max" }:
957+ if not math .isinf (val ):
958+ metrics_accum [k ] = max (metrics_accum [k ], val )
959+ else :
960+ metrics_accum [k ] += val
908961
909962 return loss_accum , metrics_accum
910963
0 commit comments