@@ -20,7 +20,8 @@ def __init__(
2020 clip_range_low : Optional [float ] = None ,
2121 clip_range_high : Optional [float ] = None ,
2222 loss_agg_mode : Optional [str ] = "token-mean" ,
23- truncate_large_is : bool = False ,
23+ truncate_adv_pos_is : bool = False ,
24+ truncate_adv_neg_is : bool = False ,
2425 truncate_is_range_low : Optional [float ] = 0.0 ,
2526 truncate_is_range_high : Optional [float ] = 2.0 ,
2627 ) -> None :
@@ -33,8 +34,12 @@ def __init__(
3334 clip_range_low: Lower bound for clipping (1.0 - clip_range_low)
3435 clip_range_high: Upper bound for clipping (1.0 + clip_range_high)
3536 loss_agg_mode: Loss aggregation mode (default: "token-mean")
36- truncate_large_is: Whether to truncate large importance sampling ratios
37- to handle calculation discrepancies between rollout and training engines
37+ truncate_adv_pos_is: Whether to truncate large importance sampling ratios
38+ when advantage is positive to handle calculation discrepancies between
39+ rollout and training engines
40+ truncate_adv_neg_is: Whether to truncate large importance sampling ratios
41+ when advantage is negative to handle calculation discrepancies between
42+ rollout and training engines
3843 truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0)
3944 truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0)
4045 """
@@ -52,17 +57,27 @@ def __init__(
5257 self .loss_agg_mode = loss_agg_mode
5358
5459 # Truncate large IS configuration
55- self .truncate_large_is = truncate_large_is
56- if truncate_large_is :
60+ self .truncate_adv_pos_is = truncate_adv_pos_is
61+ self .truncate_adv_neg_is = truncate_adv_neg_is
62+ if truncate_adv_pos_is :
5763 self .truncate_is_range_low = truncate_is_range_low
58- self .truncate_is_range_high = truncate_is_range_high
5964 assert (
6065 self .truncate_is_range_low is not None
6166 ), "truncate_is_range_low must be specified."
67+ assert (
68+ self .truncate_is_range_low >= 0.0
69+ ), "truncate_is_range_low must be non-negative."
70+ assert (self .truncate_is_range_low < 1.0 - self .clip_range_low
71+ ), "truncate_is_range_low must be less than 1.0 - clip_range_low."
72+ if truncate_adv_neg_is :
73+ self .truncate_is_range_high = truncate_is_range_high
6274 assert (
6375 self .truncate_is_range_high is not None
6476 ), "truncate_is_range_high must be specified."
65- assert self .truncate_is_range_low >= 0.0 , "truncate_is_range_low must be non-negative."
77+ assert (
78+ self .truncate_is_range_high > 1.0 + self .clip_range_high
79+ ), "truncate_is_range_high must be greater than clip_range_high + 1.0."
80+ if truncate_adv_pos_is and truncate_adv_neg_is :
6681 assert (
6782 self .truncate_is_range_high > self .truncate_is_range_low
6883 ), "truncate_is_range_high must be greater than truncate_is_range_low."
@@ -79,36 +94,54 @@ def __call__( # type: ignore
7994 ratio = torch .exp (negative_approx_kl )
8095 ppo_kl = masked_mean (- negative_approx_kl , action_mask )
8196
82- # Truncate large IS ratios if enabled
83- # This helps stabilize training when there are calculation discrepancies between
84- # rollout and training engines, especially for small probabilities
85- if self .truncate_large_is :
86- # Track how often truncation occurs (before actually truncating)
87- # More efficient than cloning: directly check which values fall outside bounds
88- ratio_detached = ratio .detach ()
89- is_truncate_frac = masked_mean (
90- (ratio_detached < self .truncate_is_range_low ).float (), action_mask
91- ) + masked_mean ((ratio_detached > self .truncate_is_range_high ).float (), action_mask )
92- ratio = torch .clamp (ratio , self .truncate_is_range_low , self .truncate_is_range_high )
93-
94- pg_losses = - advantages * ratio
97+ # First clipping by clip_range, and calculate pg_clipfrac
98+ pg_losses1 = - advantages * ratio
9599 pg_losses2 = - advantages * torch .clamp (
96100 ratio , 1.0 - self .clip_range_low , 1.0 + self .clip_range_high # type: ignore
97101 )
102+ pg_losses_clip = torch .maximum (pg_losses1 , pg_losses2 )
103+ pg_clipfrac = masked_mean (torch .gt (pg_losses2 , pg_losses1 ).float (), action_mask )
104+
105+ # After clipped by clip_range, further truncate IS ratios if enabled
106+ # This helps stabilize training when there are calculation discrepancies between
107+ # rollout and training engines, especially for small probabilities
108+ pg_truncfrac_pos , pg_truncfrac_neg = 0.0 , 0.0
109+ pg_losses_trunc = pg_losses_clip
110+
111+ # Add IS truncation for positive advantages
112+ if self .truncate_adv_pos_is :
113+ pg_losses_pos_trunc = - advantages * self .truncate_is_range_low
114+ pg_truncfrac_pos = masked_mean (
115+ torch .lt (pg_losses_pos_trunc , pg_losses_trunc ) * (advantages > 0 ).float (),
116+ action_mask ,
117+ )
118+ pg_losses_pos = torch .minimum (pg_losses_trunc , pg_losses_pos_trunc )
119+ pg_losses_trunc = torch .where (advantages > 0 , pg_losses_pos , pg_losses_trunc )
120+
121+ # Add IS truncation for negative advantages
122+ if self .truncate_adv_neg_is :
123+ pg_losses_neg_trunc = - advantages * self .truncate_is_range_high
124+ pg_truncfrac_neg = masked_mean (
125+ torch .lt (pg_losses_neg_trunc , pg_losses_trunc ) * (advantages < 0 ).float (),
126+ action_mask ,
127+ )
128+ pg_losses_neg = torch .minimum (pg_losses_trunc , pg_losses_neg_trunc )
129+ pg_losses_trunc = torch .where (advantages < 0 , pg_losses_neg , pg_losses_trunc )
98130
99131 pg_loss = masked_loss (
100- torch . max ( pg_losses , pg_losses2 ) , action_mask , loss_agg_mode = self .loss_agg_mode
132+ pg_losses_trunc , action_mask , loss_agg_mode = self .loss_agg_mode
101133 )
102- pg_clipfrac = masked_mean (torch .gt (pg_losses2 , pg_losses ).float (), action_mask )
103134 metrics = {
104135 "pg_clipfrac" : pg_clipfrac .detach ().item (),
105136 "ppo_kl" : ppo_kl .detach ().item (),
106137 "pg_loss" : pg_loss .detach ().item (),
107138 }
108139
109140 # Add IS truncation metrics if enabled
110- if self .truncate_large_is :
111- metrics ["is_truncate_frac" ] = is_truncate_frac .detach ().item ()
141+ if self .truncate_adv_pos_is :
142+ metrics ["is_truncate_frac_pos" ] = pg_truncfrac_pos .detach ().item ()
143+ if self .truncate_adv_neg_is :
144+ metrics ["is_truncate_frac_neg" ] = pg_truncfrac_neg .detach ().item ()
112145
113146 return pg_loss , metrics
114147
@@ -117,7 +150,8 @@ def default_args(cls) -> Dict:
117150 return {
118151 "clip_range" : 0.2 ,
119152 "loss_agg_mode" : "token-mean" ,
120- "truncate_large_is" : False ,
153+ "truncate_adv_pos_is" : False ,
154+ "truncate_adv_neg_is" : False ,
121155 "truncate_is_range_low" : 0.0 ,
122156 "truncate_is_range_high" : 2.0 ,
123157 }
0 commit comments