Skip to content

Commit 9e8b43e

Browse files
committed
add truncate_adv_both_is()
1 parent 1492a46 commit 9e8b43e

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

tests/algorithm/policy_loss_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,41 @@ def test_ppo_policy_loss_with_truncate_adv_neg_is(self):
166166
)
167167
self.assertGreaterEqual(metrics["is_truncate_frac_neg"], 0.0)
168168
self.assertLessEqual(metrics["is_truncate_frac_neg"], 1.0)
169+
170+
def test_ppo_policy_loss_with_truncate_adv_both_is(self):
171+
"""Test PPO policy loss with truncate large IS enabled."""
172+
policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
173+
policy_loss_fn_args = policy_loss_fn_cls.default_args()
174+
# truncate large IS when advantage is negative
175+
policy_loss_fn_args["truncate_adv_pos_is"] = True
176+
policy_loss_fn_args["truncate_is_range_low"] = 0.5
177+
policy_loss_fn_args["truncate_adv_neg_is"] = True
178+
policy_loss_fn_args["truncate_is_range_high"] = 2.0
179+
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
180+
loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
181+
182+
# Expected values with IS truncation enabled when advantage is negative
183+
# ppo_loss_truncated = ppo_loss_adv_pos_truncated + ppo_loss_adv_neg_truncated - ppo_loss_untruncated
184+
ppo_loss_truncated = torch.tensor(0.2227930873632431)
185+
pg_clipfrac = torch.tensor(0.3541666567325592)
186+
ppo_kl = torch.tensor(-0.21663446724414825)
187+
is_truncate_frac_pos_expected = torch.tensor(0.02083333395421505)
188+
is_truncate_frac_neg_expected = torch.tensor(0.1041666641831398)
189+
190+
self.assertTrue(torch.allclose(loss, ppo_loss_truncated))
191+
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
192+
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
193+
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated))
194+
# Check that IS truncation metric is present and has expected value
195+
self.assertIn("is_truncate_frac_pos", metrics)
196+
self.assertTrue(
197+
torch.allclose(torch.tensor(metrics["is_truncate_frac_pos"]), is_truncate_frac_pos_expected)
198+
)
199+
self.assertIn("is_truncate_frac_neg", metrics)
200+
self.assertTrue(
201+
torch.allclose(torch.tensor(metrics["is_truncate_frac_neg"]), is_truncate_frac_neg_expected)
202+
)
203+
self.assertGreaterEqual(metrics["is_truncate_frac_pos"], 0.0)
204+
self.assertLessEqual(metrics["is_truncate_frac_pos"], 1.0)
205+
self.assertGreaterEqual(metrics["is_truncate_frac_neg"], 0.0)
206+
self.assertLessEqual(metrics["is_truncate_frac_neg"], 1.0)

0 commit comments

Comments
 (0)