@@ -21,6 +21,8 @@ def __init__(
         clip_range_high: Optional[float] = None,
         clip_ratio_c: float = 3.0,
         loss_agg_mode: Optional[str] = "token-mean",
+        enable_sequence_masking: bool = False,  # introduced in DeepSeek-V3.2
+        delta_sequence_masking: float = 0.1,
     ) -> None:
         super().__init__(backend=backend)
         if clip_range_low is None:
@@ -36,6 +38,8 @@ def __init__(
         assert self.clip_range_low is not None, "clip_range_low must be specified."
         assert self.clip_range_high is not None, "clip_range_high must be specified."
         self.loss_agg_mode = loss_agg_mode
+        self.enable_sequence_masking = enable_sequence_masking
+        self.delta_sequence_masking = delta_sequence_masking
 
     def __call__(  # type: ignore
         self,
@@ -66,13 +70,45 @@ def __call__(  # type: ignore
             torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
         )
         pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+
+        # Apply sequence masking if enabled
+        if self.enable_sequence_masking:
+            # Compute sequence-level KL divergence: mean KL per sequence
+            # Shape: (batch_size, seq_len) -> (batch_size,)
+            kl_per_token = -negative_approx_kl  # KL divergence per token
+            sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (
+                action_mask.sum(dim=-1) + 1e-10
+            )
+
+            # Create mask: drop tokens with negative advantages when sequence KL is high
+            # Token-level advantage check: (batch_size, seq_len)
+            has_negative_advantage = advantages < 0
+            # Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
+            exceeds_kl_threshold = (
+                (sequence_kl > self.delta_sequence_masking).unsqueeze(-1).expand_as(advantages)
+            )
+            # Mask tokens that both have a negative advantage AND sit in a high-KL sequence
+            should_mask = has_negative_advantage & exceeds_kl_threshold
+            sequence_mask = (~should_mask).float()
+
+            # Apply the sequence mask to the losses
+            pg_losses = pg_losses * sequence_mask
+
+            metrics_seq_mask = {
+                "seq_mask/masked_tokens": should_mask.float().sum().item()
+                / (action_mask.sum().item() + 1e-10),
+                "seq_mask/mean_sequence_kl": sequence_kl.mean().detach().item(),
+            }
+
         pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         metrics = {
             "pg_clipfrac": pg_clip_frac.detach().item(),
             "pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
         }
+        if self.enable_sequence_masking:
+            metrics.update(metrics_seq_mask)
         return pg_loss, metrics
 
     @classmethod
@@ -81,4 +117,6 @@ def default_args(cls) -> Dict:
             "clip_range": 0.2,
             "clip_ratio_c": 3.0,
             "loss_agg_mode": "token-mean",
+            "enable_sequence_masking": False,
+            "delta_sequence_masking": 0.1,
         }
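
For reference, here is a minimal standalone sketch of the masking rule this commit adds, run on toy tensors. The variable names (`negative_approx_kl`, `advantages`, `action_mask`) follow the diff; the tensor values, the batch shape, and the printed outputs are illustrative assumptions, not taken from the repository:

```python
import torch

# A minimal sketch of the sequence-masking rule from the diff above.
# Toy batch: 2 sequences x 4 tokens; values chosen only for illustration.
delta_sequence_masking = 0.1  # matches the default in the diff

negative_approx_kl = torch.tensor(
    [[-0.05, -0.05, -0.30, -0.20],   # sequence 0: high mean KL (0.15)
     [-0.01, -0.02, -0.01, -0.02]]   # sequence 1: low mean KL (0.015)
)
advantages = torch.tensor(
    [[0.5, -0.4, 0.3, -0.2],
     [-0.1, 0.2, -0.3, 0.4]]
)
action_mask = torch.ones_like(advantages)  # every token is an action here

# Sequence-level KL: masked mean of per-token KL, shape (batch_size,)
kl_per_token = -negative_approx_kl
sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (action_mask.sum(dim=-1) + 1e-10)
print(sequence_kl)  # tensor([0.1500, 0.0150]) -> only sequence 0 exceeds delta

# Mask tokens with negative advantage inside high-KL sequences
exceeds = (sequence_kl > delta_sequence_masking).unsqueeze(-1).expand_as(advantages)
should_mask = (advantages < 0) & exceeds
sequence_mask = (~should_mask).float()
print(sequence_mask)
# tensor([[1., 0., 1., 0.],    <- negative-advantage tokens dropped in sequence 0
#         [1., 1., 1., 1.]])   <- sequence 1 is below delta, left untouched
```

One design note, hedged on internals the diff does not show: the hunk zeroes the per-token losses but leaves `action_mask` untouched, so if `aggregate_loss` with `"token-mean"` normalizes by `action_mask.sum()`, masked tokens still count in the denominator; masking sends their contribution to zero rather than re-weighting the surviving tokens.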