Commit 6400b5d

update policy loss value test
1 parent 221b075 commit 6400b5d

File tree

1 file changed (+19, -17)


tests/algorithm/policy_loss_test.py

Lines changed: 19 additions & 17 deletions
@@ -117,26 +117,28 @@ def test_mix_policy_loss(self):
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
 
     def test_ppo_policy_loss_with_sequence_masking(self):
-        """Test PPO policy loss with sequence masking enabled"""
         policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
         policy_loss_fn_args["enable_sequence_masking"] = True
         policy_loss_fn_args["delta"] = 0.1
         policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
         loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
-
-        # Test that sequence masking metrics are present
-        self.assertIn("seq_mask/masked_tokens", metrics)
-        self.assertIn("seq_mask/mean_sequence_kl", metrics)
-
-        # Test that masked_tokens is between 0 and 1
-        self.assertGreaterEqual(metrics["seq_mask/masked_tokens"], 0.0)
-        self.assertLessEqual(metrics["seq_mask/masked_tokens"], 1.0)
-
-        # Test that loss is different from non-masked version (if masking occurred)
-        policy_loss_fn_no_mask = policy_loss_fn_cls(**policy_loss_fn_cls.default_args())
-        loss_no_mask, _ = policy_loss_fn_no_mask(log_prob=self.logprob, **self.input_data.batch)
-
-        # Loss should be different if tokens were masked
-        if metrics["seq_mask/masked_tokens"] > 0:
-            self.assertFalse(torch.allclose(loss, loss_no_mask))
+        ppo_loss_masked = torch.tensor(0.22175675630569458)
+        pg_clipfrac = torch.tensor(0.3541666567325592)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        pg_clipfrac_lower = torch.tensor(0.0625)
+        masked_tokens = torch.tensor(0.16666666666631944)
+        mean_sequence_kl = torch.tensor(-0.21027061343193054)
+        self.assertTrue(torch.allclose(loss, ppo_loss_masked))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_masked))
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["pg_clipfrac_lower"]), pg_clipfrac_lower)
+        )
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["seq_mask/masked_tokens"]), masked_tokens)
+        )
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["seq_mask/mean_sequence_kl"]), mean_sequence_kl)
+        )
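
For readers unfamiliar with the feature under test: sequence masking gates whole responses out of the PPO objective when they have drifted too far from the old policy. The sketch below is a minimal, hypothetical illustration of that idea, not this repository's implementation; the function name masked_ppo_loss, its signature, and the exact rule (drop every token of a sequence whose mean approximate KL exceeds delta in magnitude) are assumptions chosen only to produce metrics shaped like the seq_mask/masked_tokens and seq_mask/mean_sequence_kl values pinned in the test above.

import torch

def masked_ppo_loss(log_prob, old_log_prob, advantages, action_mask,
                    clip_range=0.2, delta=0.1):
    """Illustrative PPO clip loss with per-sequence KL masking.

    action_mask is a float tensor of 1s/0s marking valid response tokens.
    """
    # Approximate per-token KL between the old and new policies.
    kl = old_log_prob - log_prob                              # (batch, seq)
    token_counts = action_mask.sum(dim=-1).clamp(min=1)
    seq_kl = (kl * action_mask).sum(dim=-1) / token_counts    # (batch,)
    # Hypothetical rule: mask all tokens of sequences with |mean KL| > delta.
    keep = (seq_kl.abs() <= delta).float().unsqueeze(-1)      # (batch, 1)
    mask = action_mask * keep
    # Standard PPO clipped surrogate, averaged over the surviving tokens.
    ratio = torch.exp(log_prob - old_log_prob)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    loss = -(torch.min(surr1, surr2) * mask).sum() / mask.sum().clamp(min=1)
    metrics = {
        # Fraction of originally valid tokens that were gated out.
        "seq_mask/masked_tokens": 1.0 - (mask.sum() / action_mask.sum().clamp(min=1)).item(),
        "seq_mask/mean_sequence_kl": seq_kl.mean().item(),
    }
    return loss, metrics

Under a rule of this shape, the masked_tokens value of roughly 1/6 asserted above would mean about one sixth of the valid tokens were gated out; the authoritative behavior is whatever the "ppo" loss registered in POLICY_LOSS_FN implements.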
