@@ -23,6 +23,7 @@ def __init__(
2323 loss_agg_mode : Optional [str ] = "token-mean" ,
2424 enable_sequence_masking : bool = False , # introduced in DeepseekV3.2
2525 delta_sequence_masking : float = 0.1 ,
26+ fallback_to_policy_gradient : bool = False ,
2627 ) -> None :
2728 super ().__init__ (backend = backend )
2829 if clip_range_low is None :
@@ -40,6 +41,7 @@ def __init__(
4041 self .loss_agg_mode = loss_agg_mode
4142 self .enable_sequence_masking = enable_sequence_masking
4243 self .delta_sequence_masking = delta_sequence_masking
44+ self .fallback_to_policy_gradient = fallback_to_policy_gradient
4345
4446 def __call__ ( # type: ignore
4547 self ,
@@ -50,6 +52,9 @@ def __call__( # type: ignore
5052 ** kwargs ,
5153 ) -> Tuple [torch .Tensor , Dict ]:
5254 negative_approx_kl = logprob - old_logprob
55+ if self .fallback_to_policy_gradient :
56+ # ignore vllm logprob difference and use pure policy gradient loss
57+ negative_approx_kl = logprob - logprob .detach ()
5358 # Clamp negative_approx_kl for stability
5459 negative_approx_kl = torch .clamp (negative_approx_kl , min = - 20.0 , max = 20.0 )
5560 ratio = torch .exp (negative_approx_kl )
@@ -119,4 +124,5 @@ def default_args(cls) -> Dict:
119124 "loss_agg_mode" : "token-mean" ,
120125 "enable_sequence_masking" : False ,
121126 "delta_sequence_masking" : 0.1 ,
127+ "fallback_to_policy_gradient" : False ,
122128 }
0 commit comments