# -*- coding: utf-8 -*-
"""Tests for policy loss functions (ppo, sft, dpo, opmd, mix)."""

import unittest

import torch
from verl import DataProto

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN


class VerlPolicyLossTest(unittest.TestCase):
    def setUp(self):
        # Fix every RNG source so the synthetic batch below (and the
        # hard-coded reference values asserted in each test) stays reproducible.
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Synthetic batch of 5 responses x 20 tokens: log-probs and advantages
        # drawn uniformly from [-1, 1), masks drawn as random booleans.
        shape = (5, 20)
        self.logprob = 2 * torch.rand(shape) - 1
        self.input_data = DataProto.from_dict(
            {
                "old_log_probs": 2 * torch.rand(shape) - 1,
                "ref_log_prob": 2 * torch.rand(shape) - 1,
                "response_mask": torch.rand(shape) > 0.5,
                "advantages": 2 * torch.rand(shape) - 1,
                "is_expert_mask": torch.rand(shape[0]) > 0.5,
            }
        )

    def test_ppo_policy_loss(self):
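        # Shared pattern across the tests: build the registered loss function
        # with its default arguments, run it on the seeded batch, and compare
        # the loss and reported metrics against fixed reference values.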
        policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
        ppo_loss = torch.tensor(0.28560468554496765)
        pg_clipfrac = torch.tensor(0.3541666567325592)
        ppo_kl = torch.tensor(-0.21663446724414825)
        self.assertTrue(torch.allclose(loss, ppo_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss))

    def test_sft_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("sft")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
        sft_loss = torch.tensor(-0.07560186833143234)
        self.assertTrue(torch.allclose(loss, sft_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["sft_loss"]), sft_loss))

    def test_dpo_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("dpo")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
        dpo_loss = torch.tensor(0.5406752228736877)
        chosen_reward = torch.tensor(0.7082431316375732)
        rejected_reward = torch.tensor(0.3757950782775879)
        accuracy_mean = torch.tensor(1.0)
        self.assertTrue(torch.allclose(loss, dpo_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["chosen_reward"]), chosen_reward))
        self.assertTrue(torch.allclose(torch.tensor(metrics["rejected_reward"]), rejected_reward))
        self.assertTrue(torch.allclose(torch.tensor(metrics["accuracy_mean"]), accuracy_mean))
        self.assertTrue(torch.allclose(torch.tensor(metrics["dpo_loss"]), dpo_loss))

    def test_opmd_policy_loss(self):
        policy_loss_fn_cls = POLICY_LOSS_FN.get("opmd")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
        opmd_loss = torch.tensor(-0.009589947760105133)
        self.assertTrue(torch.allclose(loss, opmd_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["opmd_loss"]), opmd_loss))

    def test_mix_policy_loss(self):
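        # The mix loss reports its metrics under "usual/" and "expert/"
        # prefixes; "is_expert_mask" in the batch marks the expert samples.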
        policy_loss_fn_cls = POLICY_LOSS_FN.get("mix")
        policy_loss_fn_args = policy_loss_fn_cls.default_args()
        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
        mix_loss = torch.tensor(0.6581965088844299)
        pg_clipfrac = torch.tensor(0.7777777910232544)
        ppo_kl = torch.tensor(-1.0737695693969727)
        pg_loss = torch.tensor(0.7236452102661133)
        sft_loss = torch.tensor(0.06915830634534359)
        self.assertTrue(torch.allclose(loss, mix_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac))
        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl))
        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
        self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
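

if __name__ == "__main__":
    # Standard unittest entry point so the file can also be run directly.
    unittest.main()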