@@ -21,8 +21,8 @@ def __init__(
2121 clip_range_high : Optional [float ] = None ,
2222 clip_ratio_c : float = 3.0 ,
2323 loss_agg_mode : Optional [str ] = "token-mean" ,
24- enable_sequence_masking : bool = False ,
25- delta : float = 0.1 ,
24+ enable_sequence_masking : bool = False , # introduced in DeepseekV3.2
25+ delta_sequence_masking : float = 0.1 ,
2626 ) -> None :
2727 super ().__init__ (backend = backend )
2828 if clip_range_low is None :
@@ -39,7 +39,7 @@ def __init__(
3939 assert self .clip_range_high is not None , "clip_range_high must be specified."
4040 self .loss_agg_mode = loss_agg_mode
4141 self .enable_sequence_masking = enable_sequence_masking
42- self .delta = delta
42+ self .delta_sequence_masking = delta_sequence_masking
4343
4444 def __call__ ( # type: ignore
4545 self ,
@@ -55,8 +55,23 @@ def __call__( # type: ignore
5555 ratio = torch .exp (negative_approx_kl )
5656 ppo_kl = masked_mean (- negative_approx_kl , action_mask )
5757
58- # Compute sequence masking if enabled
59- sequence_mask = torch .ones_like (advantages )
58+ pg_losses1 = - advantages * ratio
59+ pg_losses2 = - advantages * torch .clamp (
60+ ratio , 1.0 - self .clip_range_low , 1.0 + self .clip_range_high # type: ignore
61+ )
62+
63+ clip_pg_losses1 = torch .maximum (pg_losses1 , pg_losses2 )
64+
65+ pg_clip_frac = masked_mean (torch .gt (pg_losses2 , pg_losses1 ).float (), action_mask )
66+
67+ pg_losses3 = - advantages * self .clip_ratio_c
68+ clip_pg_losses2 = torch .min (pg_losses3 , clip_pg_losses1 )
69+ pg_clipfrac_lower = masked_mean (
70+ torch .gt (clip_pg_losses1 , pg_losses3 ) * (advantages < 0 ).float (), action_mask
71+ )
72+ pg_losses = torch .where (advantages < 0 , clip_pg_losses2 , clip_pg_losses1 )
73+
74+ # Apply sequence masking if enabled
6075 if self .enable_sequence_masking :
6176 # Compute sequence-level KL divergence: mean KL per sequence
6277 # Shape: (batch_size, seq_len) -> (batch_size,)
@@ -69,38 +84,21 @@ def __call__( # type: ignore
6984 # Token-level advantage check: (batch_size, seq_len)
7085 has_negative_advantage = advantages < 0
7186 # Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
72- exceeds_kl_threshold = (sequence_kl > self .delta ).unsqueeze (- 1 ).expand_as (advantages )
87+ exceeds_kl_threshold = (
88+ (sequence_kl > self .delta_sequence_masking ).unsqueeze (- 1 ).expand_as (advantages )
89+ )
7390 # Mask tokens that are both negative advantage AND in high-KL sequences
7491 should_mask = has_negative_advantage & exceeds_kl_threshold
7592 sequence_mask = (~ should_mask ).float ()
7693
94+ # Apply sequence mask to the losses
95+ pg_losses = pg_losses * sequence_mask
96+
7797 metrics_seq_mask = {
7898 "seq_mask/masked_tokens" : should_mask .float ().sum ().item ()
7999 / (action_mask .sum ().item () + 1e-10 ),
80100 "seq_mask/mean_sequence_kl" : sequence_kl .mean ().detach ().item (),
81101 }
82- else :
83- metrics_seq_mask = {}
84-
85- pg_losses1 = - advantages * ratio
86- pg_losses2 = - advantages * torch .clamp (
87- ratio , 1.0 - self .clip_range_low , 1.0 + self .clip_range_high # type: ignore
88- )
89-
90- clip_pg_losses1 = torch .maximum (pg_losses1 , pg_losses2 )
91-
92- pg_clip_frac = masked_mean (torch .gt (pg_losses2 , pg_losses1 ).float (), action_mask )
93-
94- pg_losses3 = - advantages * self .clip_ratio_c
95- clip_pg_losses2 = torch .min (pg_losses3 , clip_pg_losses1 )
96- pg_clipfrac_lower = masked_mean (
97- torch .gt (clip_pg_losses1 , pg_losses3 ) * (advantages < 0 ).float (), action_mask
98- )
99- pg_losses = torch .where (advantages < 0 , clip_pg_losses2 , clip_pg_losses1 )
100-
101- # Apply sequence mask to the losses
102- if self .enable_sequence_masking :
103- pg_losses = pg_losses * sequence_mask
104102
105103 pg_loss = aggregate_loss (pg_losses , action_mask , loss_agg_mode = self .loss_agg_mode )
106104 metrics = {
@@ -109,7 +107,8 @@ def __call__( # type: ignore
109107 "ppo_kl" : ppo_kl .detach ().item (),
110108 "pg_loss" : pg_loss .detach ().item (),
111109 }
112- metrics .update (metrics_seq_mask )
110+ if self .enable_sequence_masking :
111+ metrics .update (metrics_seq_mask )
113112 return pg_loss , metrics
114113
115114 @classmethod
@@ -119,5 +118,5 @@ def default_args(cls) -> Dict:
119118 "clip_ratio_c" : 3.0 ,
120119 "loss_agg_mode" : "token-mean" ,
121120 "enable_sequence_masking" : False ,
122- "delta " : 0.1 ,
121+ "delta_sequence_masking " : 0.1 ,
123122 }