@@ -21,6 +21,8 @@ def __init__(
         clip_range_high: Optional[float] = None,
         clip_ratio_c: float = 3.0,
         loss_agg_mode: Optional[str] = "token-mean",
+        enable_sequence_masking: bool = False,
+        delta: float = 0.1,
     ) -> None:
         super().__init__(backend=backend)
         if clip_range_low is None:
@@ -36,6 +38,8 @@ def __init__(
         assert self.clip_range_low is not None, "clip_range_low must be specified."
         assert self.clip_range_high is not None, "clip_range_high must be specified."
         self.loss_agg_mode = loss_agg_mode
+        self.enable_sequence_masking = enable_sequence_masking
+        self.delta = delta

     def __call__(  # type: ignore
         self,
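For context, `__call__` below relies on two helpers whose implementations are not part of this diff. A minimal sketch of what they plausibly compute, assuming `masked_mean` is the usual mask-weighted mean and `"token-mean"` aggregation averages per-token losses over valid tokens (the real definitions live elsewhere in the repo and may differ):

```python
import torch

def masked_mean(x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    # Assumed behavior: mean of x over positions where mask == 1.
    return (x * mask).sum() / (mask.sum() + eps)

def aggregate_loss(losses: torch.Tensor, mask: torch.Tensor, loss_agg_mode: str = "token-mean") -> torch.Tensor:
    # Assumed behavior: "token-mean" averages per-token losses over all valid
    # tokens in the batch; other aggregation modes are not sketched here.
    if loss_agg_mode == "token-mean":
        return masked_mean(losses, mask)
    raise NotImplementedError(f"Unknown loss_agg_mode: {loss_agg_mode}")
```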
@@ -51,6 +55,33 @@ def __call__(  # type: ignore
         ratio = torch.exp(negative_approx_kl)
         ppo_kl = masked_mean(-negative_approx_kl, action_mask)

+        # Compute sequence masking if enabled
+        sequence_mask = torch.ones_like(advantages)
+        if self.enable_sequence_masking:
+            # Compute sequence-level KL divergence: mean KL per sequence
+            # Shape: (batch_size, seq_len) -> (batch_size,)
+            kl_per_token = -negative_approx_kl  # KL divergence per token
+            sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (
+                action_mask.sum(dim=-1) + 1e-10
+            )
+
+            # Create mask: mask out tokens with negative advantages when sequence KL is high
+            # Token-level advantage check: (batch_size, seq_len)
+            has_negative_advantage = advantages < 0
+            # Sequence-level KL check: (batch_size,) -> (batch_size, 1) -> (batch_size, seq_len)
+            exceeds_kl_threshold = (sequence_kl > self.delta).unsqueeze(-1).expand_as(advantages)
+            # Mask tokens that are both negative advantage AND in high-KL sequences
+            should_mask = has_negative_advantage & exceeds_kl_threshold
+            sequence_mask = (~should_mask).float()
+
+            metrics_seq_mask = {
+                # Count only valid tokens so padding cannot inflate the ratio
+                "seq_mask/masked_tokens": (should_mask.float() * action_mask).sum().item()
+                / (action_mask.sum().item() + 1e-10),
+                "seq_mask/mean_sequence_kl": sequence_kl.mean().detach().item(),
+            }
+        else:
+            metrics_seq_mask = {}
+
         pg_losses1 = -advantages * ratio
         pg_losses2 = -advantages * torch.clamp(
             ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
@@ -66,13 +97,18 @@ def __call__(  # type: ignore
             torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
         )
         pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+
+        # Apply sequence mask to the losses
+        pg_losses = pg_losses * sequence_mask
+
         pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
         metrics = {
             "pg_clipfrac": pg_clip_frac.detach().item(),
             "pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
         }
+        metrics.update(metrics_seq_mask)
         return pg_loss, metrics

     @classmethod
@@ -81,4 +117,6 @@ def default_args(cls) -> Dict:
             "clip_range": 0.2,
             "clip_ratio_c": 3.0,
             "loss_agg_mode": "token-mean",
+            "enable_sequence_masking": False,
+            "delta": 0.1,
         }
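To make the masking rule concrete, here is a standalone sketch of the new sequence-masking step on toy tensors (batch of 2 sequences, 3 tokens each; all values are illustrative, and `delta = 0.1` matches the default above). The diff treats `-negative_approx_kl` as the per-token KL estimate, since `ratio = torch.exp(negative_approx_kl)`:

```python
import torch

delta = 0.1
# Toy inputs mirroring the names in the diff (values are made up).
negative_approx_kl = torch.tensor([[-0.05, -0.20, -0.30],
                                   [-0.01, -0.02, -0.03]])
advantages = torch.tensor([[0.5, -0.4, -0.1],
                           [-0.2, 0.3, 0.1]])
action_mask = torch.ones_like(advantages)

kl_per_token = -negative_approx_kl
sequence_kl = (kl_per_token * action_mask).sum(dim=-1) / (action_mask.sum(dim=-1) + 1e-10)
# sequence_kl ~= [0.1833, 0.0200]; only the first sequence exceeds delta.

exceeds_kl_threshold = (sequence_kl > delta).unsqueeze(-1).expand_as(advantages)
should_mask = (advantages < 0) & exceeds_kl_threshold
sequence_mask = (~should_mask).float()
print(sequence_mask)
# tensor([[1., 0., 0.],
#         [1., 1., 1.]])
```

Only tokens that combine a negative advantage with a high-KL (off-policy) sequence are dropped, so positive-advantage tokens keep contributing to the gradient even in sequences that have drifted past the threshold.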