@@ -109,39 +109,60 @@ def test_mix_policy_loss(self):
109109 self .assertTrue (torch .allclose (torch .tensor (metrics ["expert/sft_loss" ]), sft_loss ))
110110 self .assertTrue (torch .allclose (torch .tensor (metrics ["loss" ]), mix_loss ))
111111
112- def test_ppo_policy_loss_with_truncate_is (self ):
112+ def test_ppo_policy_loss_with_truncate_adv_pos_is (self ):
113113 """Test PPO policy loss with small-IS truncation enabled for positive advantages."""
114114 policy_loss_fn_cls = POLICY_LOSS_FN .get ("ppo" )
115115 policy_loss_fn_args = policy_loss_fn_cls .default_args ()
116- # Enable truncate large IS with default bounds [0.0, 2.0]
116+ # Truncate small IS when advantage is positive
117+ policy_loss_fn_args ["truncate_adv_neg_is" ] = False
117118 policy_loss_fn_args ["truncate_adv_pos_is" ] = True
119+ policy_loss_fn_args ["truncate_is_range_low" ] = 0.5
120+ policy_loss_fn = policy_loss_fn_cls (** policy_loss_fn_args )
121+ loss , metrics = policy_loss_fn (log_prob = self .logprob , ** self .input_data .batch )
122+
123+ # Expected values with IS truncation enabled when advantage is positive
124+ ppo_loss_truncated = torch .tensor (0.28531503677368164 )
125+ pg_clipfrac = torch .tensor (0.3541666567325592 )
126+ ppo_kl = torch .tensor (- 0.21663446724414825 )
127+ is_truncate_frac_pos_expected = torch .tensor (0.02083333395421505 )
128+
129+ self .assertTrue (torch .allclose (loss , ppo_loss_truncated ))
130+ self .assertTrue (torch .allclose (torch .tensor (metrics ["pg_clipfrac" ]), pg_clipfrac ))
131+ self .assertTrue (torch .allclose (torch .tensor (metrics ["ppo_kl" ]), ppo_kl ))
132+ self .assertTrue (torch .allclose (torch .tensor (metrics ["pg_loss" ]), ppo_loss_truncated ))
133+ # Check that IS truncation metric is present and has expected value
134+ self .assertIn ("is_truncate_frac_pos" , metrics )
135+ self .assertTrue (
136+ torch .allclose (torch .tensor (metrics ["is_truncate_frac_pos" ]), is_truncate_frac_pos_expected )
137+ )
138+ self .assertGreaterEqual (metrics ["is_truncate_frac_pos" ], 0.0 )
139+ self .assertLessEqual (metrics ["is_truncate_frac_pos" ], 1.0 )
140+
141+ def test_ppo_policy_loss_with_truncate_adv_neg_is (self ):
142+ """Test PPO policy loss with large-IS truncation enabled for negative advantages."""
143+ policy_loss_fn_cls = POLICY_LOSS_FN .get ("ppo" )
144+ policy_loss_fn_args = policy_loss_fn_cls .default_args ()
145+ # Truncate large IS when advantage is negative
146+ policy_loss_fn_args ["truncate_adv_pos_is" ] = False
118147 policy_loss_fn_args ["truncate_adv_neg_is" ] = True
119- policy_loss_fn_args ["truncate_is_range_low" ] = 0.0
120148 policy_loss_fn_args ["truncate_is_range_high" ] = 2.0
121149 policy_loss_fn = policy_loss_fn_cls (** policy_loss_fn_args )
122150 loss , metrics = policy_loss_fn (log_prob = self .logprob , ** self .input_data .batch )
123151
124- # Expected values with IS truncation enabled (range: [0.0, 2.0])
152+ # Expected values with IS truncation enabled when advantage is negative
125153 ppo_loss_truncated = torch .tensor (0.2230827361345291 )
126- pg_clipfrac_truncated = torch .tensor (0.3541666567325592 )
127- ppo_kl_truncated = torch .tensor (- 0.21663446724414825 )
128- is_truncate_frac_pos_expected = torch .tensor (0.0 )
154+ pg_clipfrac = torch .tensor (0.3541666567325592 )
155+ ppo_kl = torch .tensor (- 0.21663446724414825 )
129156 is_truncate_frac_neg_expected = torch .tensor (0.1041666641831398 )
130157
131158 self .assertTrue (torch .allclose (loss , ppo_loss_truncated ))
132- self .assertTrue (torch .allclose (torch .tensor (metrics ["pg_clipfrac" ]), pg_clipfrac_truncated ))
133- self .assertTrue (torch .allclose (torch .tensor (metrics ["ppo_kl" ]), ppo_kl_truncated ))
159+ self .assertTrue (torch .allclose (torch .tensor (metrics ["pg_clipfrac" ]), pg_clipfrac ))
160+ self .assertTrue (torch .allclose (torch .tensor (metrics ["ppo_kl" ]), ppo_kl ))
134161 self .assertTrue (torch .allclose (torch .tensor (metrics ["pg_loss" ]), ppo_loss_truncated ))
135162 # Check that IS truncation metric is present and has expected value
136- self .assertIn ("is_truncate_frac_pos" , metrics )
137163 self .assertIn ("is_truncate_frac_neg" , metrics )
138- self .assertTrue (
139- torch .allclose (torch .tensor (metrics ["is_truncate_frac_pos" ]), is_truncate_frac_pos_expected )
140- )
141164 self .assertTrue (
142165 torch .allclose (torch .tensor (metrics ["is_truncate_frac_neg" ]), is_truncate_frac_neg_expected )
143166 )
144- self .assertGreaterEqual (metrics ["is_truncate_frac_pos" ], 0.0 )
145- self .assertLessEqual (metrics ["is_truncate_frac_pos" ], 1.0 )
146167 self .assertGreaterEqual (metrics ["is_truncate_frac_neg" ], 0.0 )
147168 self .assertLessEqual (metrics ["is_truncate_frac_neg" ], 1.0 )
0 commit comments