88import torch
99
1010from trinity .algorithm .policy_loss_fn .policy_loss_fn import POLICY_LOSS_FN , PolicyLossFn
11- from trinity .algorithm .utils import masked_loss , masked_mean
11+ from trinity .algorithm .utils import aggregate_loss , masked_mean
1212
1313
1414@POLICY_LOSS_FN .register_module ("ppo" )
@@ -19,6 +19,7 @@ def __init__(
1919 clip_range : Optional [float ] = None ,
2020 clip_range_low : Optional [float ] = None ,
2121 clip_range_high : Optional [float ] = None ,
22+ clip_ratio_c : float = 3.0 ,
2223 loss_agg_mode : Optional [str ] = "token-mean" ,
2324 ) -> None :
2425 super ().__init__ (backend = backend )
@@ -30,6 +31,8 @@ def __init__(
3031 self .clip_range_high = clip_range
3132 else :
3233 self .clip_range_high = clip_range_high
34+ self .clip_ratio_c = clip_ratio_c
35+ assert clip_ratio_c > 1.0 , "clip_ratio_c must be greater than 1.0."
3336 assert self .clip_range_low is not None , "clip_range_low must be specified."
3437 assert self .clip_range_high is not None , "clip_range_high must be specified."
3538 self .loss_agg_mode = loss_agg_mode
@@ -43,20 +46,30 @@ def __call__( # type: ignore
4346 ** kwargs ,
4447 ) -> Tuple [torch .Tensor , Dict ]:
4548 negative_approx_kl = logprob - old_logprob
49+ # Clamp negative_approx_kl for stability
50+ negative_approx_kl = torch .clamp (negative_approx_kl , min = - 20.0 , max = 20.0 )
4651 ratio = torch .exp (negative_approx_kl )
4752 ppo_kl = masked_mean (- negative_approx_kl , action_mask )
4853
49- pg_losses = - advantages * ratio
54+ pg_losses1 = - advantages * ratio
5055 pg_losses2 = - advantages * torch .clamp (
5156 ratio , 1.0 - self .clip_range_low , 1.0 + self .clip_range_high # type: ignore
5257 )
5358
54- pg_loss = masked_loss (
55- torch .max (pg_losses , pg_losses2 ), action_mask , loss_agg_mode = self .loss_agg_mode
59+ clip_pg_losses1 = torch .maximum (pg_losses1 , pg_losses2 )
60+
61+ pg_clip_frac = masked_mean (torch .gt (pg_losses2 , pg_losses1 ).float (), action_mask )
62+
63+ pg_losses3 = - advantages * self .clip_ratio_c
64+ clip_pg_losses2 = torch .min (pg_losses3 , clip_pg_losses1 )
65+ pg_clipfrac_lower = masked_mean (
66+ torch .gt (clip_pg_losses1 , pg_losses3 ) * (advantages < 0 ).float (), action_mask
5667 )
57- pg_clipfrac = masked_mean (torch .gt (pg_losses2 , pg_losses ).float (), action_mask )
68+ pg_losses = torch .where (advantages < 0 , clip_pg_losses2 , clip_pg_losses1 )
69+ pg_loss = aggregate_loss (pg_losses , action_mask , loss_agg_mode = self .loss_agg_mode )
5870 metrics = {
59- "pg_clipfrac" : pg_clipfrac .detach ().item (),
71+ "pg_clipfrac" : pg_clip_frac .detach ().item (),
72+ "pg_clipfrac_lower" : pg_clipfrac_lower .detach ().item (),
6073 "ppo_kl" : ppo_kl .detach ().item (),
6174 "pg_loss" : pg_loss .detach ().item (),
6275 }
@@ -66,5 +79,6 @@ def __call__( # type: ignore
6679 def default_args (cls ) -> Dict :
6780 return {
6881 "clip_range" : 0.2 ,
82+ "clip_ratio_c" : 3.0 ,
6983 "loss_agg_mode" : "token-mean" ,
7084 }
# (end of diff)