Skip to content

Commit 1492a46

Browse files
committed
add test policy loss for positive and negative adv
1 parent ec4d46c commit 1492a46

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

tests/algorithm/policy_loss_test.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -109,39 +109,60 @@ def test_mix_policy_loss(self):
109109
self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
110110
self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
111111

112-
def test_ppo_policy_loss_with_truncate_is(self):
112+
def test_ppo_policy_loss_with_truncate_adv_pos_is(self):
113113
"""Test PPO policy loss with small-IS truncation enabled for positive advantages."""
114114
policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
115115
policy_loss_fn_args = policy_loss_fn_cls.default_args()
116-
# Enable truncate large IS with default bounds [0.0, 2.0]
116+
# Truncate small IS when advantage is positive
117+
policy_loss_fn_args["truncate_adv_neg_is"] = False
117118
policy_loss_fn_args["truncate_adv_pos_is"] = True
119+
policy_loss_fn_args["truncate_is_range_low"] = 0.5
120+
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
121+
loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
122+
123+
# Expected values with IS truncation enabled when advantage is positive
124+
ppo_loss_truncated = torch.tensor(0.28531503677368164)
125+
pg_clipfrac = torch.tensor(0.3541666567325592)
126+
ppo_kl = torch.tensor(-0.21663446724414825)
127+
is_truncate_frac_pos_expected = torch.tensor(0.02083333395421505)
128+
129+
self.assertTrue(torch.allclose(loss, ppo_loss_truncated))
130+
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
131+
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
132+
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated))
133+
# Check that IS truncation metric is present and has expected value
134+
self.assertIn("is_truncate_frac_pos", metrics)
135+
self.assertTrue(
136+
torch.allclose(torch.tensor(metrics["is_truncate_frac_pos"]), is_truncate_frac_pos_expected)
137+
)
138+
self.assertGreaterEqual(metrics["is_truncate_frac_pos"], 0.0)
139+
self.assertLessEqual(metrics["is_truncate_frac_pos"], 1.0)
140+
141+
def test_ppo_policy_loss_with_truncate_adv_neg_is(self):
142+
"""Test PPO policy loss with large-IS truncation enabled for negative advantages."""
143+
policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
144+
policy_loss_fn_args = policy_loss_fn_cls.default_args()
145+
# Truncate large IS when advantage is negative
146+
policy_loss_fn_args["truncate_adv_pos_is"] = False
118147
policy_loss_fn_args["truncate_adv_neg_is"] = True
119-
policy_loss_fn_args["truncate_is_range_low"] = 0.0
120148
policy_loss_fn_args["truncate_is_range_high"] = 2.0
121149
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
122150
loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
123151

124-
# Expected values with IS truncation enabled (range: [0.0, 2.0])
152+
# Expected values with IS truncation enabled when advantage is negative
125153
ppo_loss_truncated = torch.tensor(0.2230827361345291)
126-
pg_clipfrac_truncated = torch.tensor(0.3541666567325592)
127-
ppo_kl_truncated = torch.tensor(-0.21663446724414825)
128-
is_truncate_frac_pos_expected = torch.tensor(0.0)
154+
pg_clipfrac = torch.tensor(0.3541666567325592)
155+
ppo_kl = torch.tensor(-0.21663446724414825)
129156
is_truncate_frac_neg_expected = torch.tensor(0.1041666641831398)
130157

131158
self.assertTrue(torch.allclose(loss, ppo_loss_truncated))
132-
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_truncated))
133-
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_truncated))
159+
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
160+
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
134161
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated))
135162
# Check that IS truncation metric is present and has expected value
136-
self.assertIn("is_truncate_frac_pos", metrics)
137163
self.assertIn("is_truncate_frac_neg", metrics)
138-
self.assertTrue(
139-
torch.allclose(torch.tensor(metrics["is_truncate_frac_pos"]), is_truncate_frac_pos_expected)
140-
)
141164
self.assertTrue(
142165
torch.allclose(torch.tensor(metrics["is_truncate_frac_neg"]), is_truncate_frac_neg_expected)
143166
)
144-
self.assertGreaterEqual(metrics["is_truncate_frac_pos"], 0.0)
145-
self.assertLessEqual(metrics["is_truncate_frac_pos"], 1.0)
146167
self.assertGreaterEqual(metrics["is_truncate_frac_neg"], 0.0)
147168
self.assertLessEqual(metrics["is_truncate_frac_neg"], 1.0)

0 commit comments

Comments
 (0)