@@ -166,3 +166,41 @@ def test_ppo_policy_loss_with_truncate_adv_neg_is(self):
166166 )
167167 self .assertGreaterEqual (metrics ["is_truncate_frac_neg" ], 0.0 )
168168 self .assertLessEqual (metrics ["is_truncate_frac_neg" ], 1.0 )
169+
170+ def test_ppo_policy_loss_with_truncate_adv_both_is (self ):
171+ """Test PPO policy loss with truncate large IS enabled."""
172+ policy_loss_fn_cls = POLICY_LOSS_FN .get ("ppo" )
173+ policy_loss_fn_args = policy_loss_fn_cls .default_args ()
174+ # truncate large IS when advantage is negative
175+ policy_loss_fn_args ["truncate_adv_pos_is" ] = True
176+ policy_loss_fn_args ["truncate_is_range_low" ] = 0.5
177+ policy_loss_fn_args ["truncate_adv_neg_is" ] = True
178+ policy_loss_fn_args ["truncate_is_range_high" ] = 2.0
179+ policy_loss_fn = policy_loss_fn_cls (** policy_loss_fn_args )
180+ loss , metrics = policy_loss_fn (log_prob = self .logprob , ** self .input_data .batch )
181+
182+ # Expected values with IS truncation enabled when advantage is negative
183+ # ppo_loss_truncated = ppo_loss_adv_pos_truncated + ppo_loss_adv_pos_truncated - ppo_loss_untruncated
184+ ppo_loss_truncated = torch .tensor (0.2227930873632431 )
185+ pg_clipfrac = torch .tensor (0.3541666567325592 )
186+ ppo_kl = torch .tensor (- 0.21663446724414825 )
187+ is_truncate_frac_pos_expected = torch .tensor (0.02083333395421505 )
188+ is_truncate_frac_neg_expected = torch .tensor (0.1041666641831398 )
189+
190+ self .assertTrue (torch .allclose (loss , ppo_loss_truncated ))
191+ self .assertTrue (torch .allclose (torch .tensor (metrics ["pg_clipfrac" ]), pg_clipfrac ))
192+ self .assertTrue (torch .allclose (torch .tensor (metrics ["ppo_kl" ]), ppo_kl ))
193+ self .assertTrue (torch .allclose (torch .tensor (metrics ["pg_loss" ]), ppo_loss_truncated ))
194+ # Check that IS truncation metric is present and has expected value
195+ self .assertIn ("is_truncate_frac_pos" , metrics )
196+ self .assertTrue (
197+ torch .allclose (torch .tensor (metrics ["is_truncate_frac_pos" ]), is_truncate_frac_pos_expected )
198+ )
199+ self .assertIn ("is_truncate_frac_neg" , metrics )
200+ self .assertTrue (
201+ torch .allclose (torch .tensor (metrics ["is_truncate_frac_neg" ]), is_truncate_frac_neg_expected )
202+ )
203+ self .assertGreaterEqual (metrics ["is_truncate_frac_pos" ], 0.0 )
204+ self .assertLessEqual (metrics ["is_truncate_frac_pos" ], 1.0 )
205+ self .assertGreaterEqual (metrics ["is_truncate_frac_neg" ], 0.0 )
206+ self .assertLessEqual (metrics ["is_truncate_frac_neg" ], 1.0 )
0 commit comments