@@ -20,7 +20,29 @@ def __init__(
2020 clip_range_low : Optional [float ] = None ,
2121 clip_range_high : Optional [float ] = None ,
2222 loss_agg_mode : Optional [str ] = "token-mean" ,
23+ truncate_adv_pos_is : bool = False ,
24+ truncate_adv_neg_is : bool = False ,
25+ truncate_is_range_low : Optional [float ] = 0.0 ,
26+ truncate_is_range_high : Optional [float ] = 2.0 ,
2327 ) -> None :
28+ """
29+ Initialize PPO policy loss function.
30+
31+ Args:
32+ backend: Backend framework (default: "verl")
33+ clip_range: Symmetric clipping range for PPO
34+ clip_range_low: Lower bound for clipping (1.0 - clip_range_low)
35+ clip_range_high: Upper bound for clipping (1.0 + clip_range_high)
36+ loss_agg_mode: Loss aggregation mode (default: "token-mean")
37+ truncate_adv_pos_is: Whether to truncate large importance sampling ratios
38+ when advantage is positive to handle calculation discrepancies between
39+ rollout and training engines
40+ truncate_adv_neg_is: Whether to truncate large importance sampling ratios
41+ when advantage is negative to handle calculation discrepancies between
42+ rollout and training engines
43+ truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0)
44+ truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0)
45+ """
2446 super ().__init__ (backend = backend )
2547 if clip_range_low is None :
2648 self .clip_range_low = clip_range
@@ -34,6 +56,32 @@ def __init__(
3456 assert self .clip_range_high is not None , "clip_range_high must be specified."
3557 self .loss_agg_mode = loss_agg_mode
3658
59+ # Truncate large IS configuration
60+ self .truncate_adv_pos_is = truncate_adv_pos_is
61+ self .truncate_adv_neg_is = truncate_adv_neg_is
62+ if truncate_adv_pos_is :
63+ self .truncate_is_range_low = truncate_is_range_low
64+ assert (
65+ self .truncate_is_range_low is not None
66+ ), "truncate_is_range_low must be specified."
67+ assert (
68+ self .truncate_is_range_low >= 0.0
69+ ), "truncate_is_range_low must be non-negative."
70+ assert (self .truncate_is_range_low < 1.0 - self .clip_range_low
71+ ), "truncate_is_range_low must be less than 1.0 - clip_range_low."
72+ if truncate_adv_neg_is :
73+ self .truncate_is_range_high = truncate_is_range_high
74+ assert (
75+ self .truncate_is_range_high is not None
76+ ), "truncate_is_range_high must be specified."
77+ assert (
78+ self .truncate_is_range_high > 1.0 + self .clip_range_high
79+ ), "truncate_is_range_high must be greater than clip_range_high + 1.0."
80+ if truncate_adv_pos_is and truncate_adv_neg_is :
81+ assert (
82+ self .truncate_is_range_high > self .truncate_is_range_low
83+ ), "truncate_is_range_high must be greater than truncate_is_range_low."
84+
3785 def __call__ ( # type: ignore
3886 self ,
3987 logprob : torch .Tensor ,
@@ -46,25 +94,64 @@ def __call__( # type: ignore
4694 ratio = torch .exp (negative_approx_kl )
4795 ppo_kl = masked_mean (- negative_approx_kl , action_mask )
4896
49- pg_losses = - advantages * ratio
97+ # First clipping by clip_range, and calculate pg_clipfrac
98+ pg_losses1 = - advantages * ratio
5099 pg_losses2 = - advantages * torch .clamp (
51100 ratio , 1.0 - self .clip_range_low , 1.0 + self .clip_range_high # type: ignore
52101 )
102+ pg_losses_clip = torch .maximum (pg_losses1 , pg_losses2 )
103+ pg_clipfrac = masked_mean (torch .gt (pg_losses2 , pg_losses1 ).float (), action_mask )
104+
105+ # After clipped by clip_range, further truncate IS ratios if enabled
106+ # This helps stabilize training when there are calculation discrepancies between
107+ # rollout and training engines, especially for small probabilities
108+ pg_truncfrac_pos , pg_truncfrac_neg = 0.0 , 0.0
109+ pg_losses_trunc = pg_losses_clip
110+
111+ # Add IS truncation for positive advantages
112+ if self .truncate_adv_pos_is :
113+ pg_losses_pos_trunc = - advantages * self .truncate_is_range_low
114+ pg_truncfrac_pos = masked_mean (
115+ torch .lt (pg_losses_pos_trunc , pg_losses_trunc ) * (advantages > 0 ).float (),
116+ action_mask ,
117+ )
118+ pg_losses_pos = torch .minimum (pg_losses_trunc , pg_losses_pos_trunc )
119+ pg_losses_trunc = torch .where (advantages > 0 , pg_losses_pos , pg_losses_trunc )
120+
121+ # Add IS truncation for negative advantages
122+ if self .truncate_adv_neg_is :
123+ pg_losses_neg_trunc = - advantages * self .truncate_is_range_high
124+ pg_truncfrac_neg = masked_mean (
125+ torch .lt (pg_losses_neg_trunc , pg_losses_trunc ) * (advantages < 0 ).float (),
126+ action_mask ,
127+ )
128+ pg_losses_neg = torch .minimum (pg_losses_trunc , pg_losses_neg_trunc )
129+ pg_losses_trunc = torch .where (advantages < 0 , pg_losses_neg , pg_losses_trunc )
53130
54131 pg_loss = masked_loss (
55- torch . max ( pg_losses , pg_losses2 ) , action_mask , loss_agg_mode = self .loss_agg_mode
132+ pg_losses_trunc , action_mask , loss_agg_mode = self .loss_agg_mode
56133 )
57- pg_clipfrac = masked_mean (torch .gt (pg_losses2 , pg_losses ).float (), action_mask )
58134 metrics = {
59135 "pg_clipfrac" : pg_clipfrac .detach ().item (),
60136 "ppo_kl" : ppo_kl .detach ().item (),
61137 "pg_loss" : pg_loss .detach ().item (),
62138 }
139+
140+ # Add IS truncation metrics if enabled
141+ if self .truncate_adv_pos_is :
142+ metrics ["is_truncate_frac_pos" ] = pg_truncfrac_pos .detach ().item ()
143+ if self .truncate_adv_neg_is :
144+ metrics ["is_truncate_frac_neg" ] = pg_truncfrac_neg .detach ().item ()
145+
63146 return pg_loss , metrics
64147
65148 @classmethod
66149 def default_args (cls ) -> Dict :
67150 return {
68151 "clip_range" : 0.2 ,
69152 "loss_agg_mode" : "token-mean" ,
153+ "truncate_adv_pos_is" : False ,
154+ "truncate_adv_neg_is" : False ,
155+ "truncate_is_range_low" : 0.0 ,
156+ "truncate_is_range_high" : 2.0 ,
70157 }
0 commit comments