-
Notifications
You must be signed in to change notification settings - Fork 47
Refactor on select_keys
#84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
pan-x-c
merged 10 commits into
modelscope:algorithm_dev
from
chenyushuo:dev/refactor_on_select_keys
Jun 18, 2025
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
51d229f
Refactor on `select_keys`
chenyushuo 68c2997
Add `key_mapper`
chenyushuo 950be68
bug fix
chenyushuo f765afc
Merge branch 'algorithm_dev' of github.com:modelscope/Trinity-RFT int…
chenyushuo c10d414
fix mix_policy_loss
chenyushuo e279c75
doc fix
chenyushuo 18960d4
Doc fix
chenyushuo d2be8f5
bug fix && add unittest for policy_loss
chenyushuo ae96b5f
bug fix
chenyushuo 37386d0
doc fix
chenyushuo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """Test for policy loss functions""" | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from verl import DataProto | ||
|
|
||
| from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN | ||
|
|
||
|
|
||
| class VerlPolicyLossTest(unittest.TestCase): | ||
| def setUp(self): | ||
| seed = 42 | ||
| torch.manual_seed(seed) | ||
| torch.cuda.manual_seed(seed) | ||
| torch.cuda.manual_seed_all(seed) | ||
| torch.backends.cudnn.deterministic = True | ||
| torch.backends.cudnn.benchmark = False | ||
|
|
||
| shape = (5, 20) | ||
| self.logprob = 2 * torch.rand(shape) - 1 | ||
| self.input_data = DataProto.from_dict( | ||
| { | ||
| "old_log_probs": 2 * torch.rand(shape) - 1, | ||
| "ref_log_prob": 2 * torch.rand(shape) - 1, | ||
| "response_mask": torch.rand(shape) > 0.5, | ||
| "advantages": 2 * torch.rand(shape) - 1, | ||
| "is_expert_mask": torch.rand(shape[0]) > 0.5, | ||
| } | ||
| ) | ||
|
|
||
| def test_ppo_policy_loss(self): | ||
| policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo") | ||
| policy_loss_fn_args = policy_loss_fn_cls.default_args() | ||
| policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) | ||
| loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) | ||
| ppo_loss = torch.tensor(0.28560468554496765) | ||
| pg_clipfrac = torch.tensor(0.3541666567325592) | ||
| ppo_kl = torch.tensor(-0.21663446724414825) | ||
| self.assertTrue(torch.allclose(loss, ppo_loss)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss)) | ||
|
|
||
| def test_sft_policy_loss(self): | ||
| policy_loss_fn_cls = POLICY_LOSS_FN.get("sft") | ||
| policy_loss_fn_args = policy_loss_fn_cls.default_args() | ||
| policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) | ||
| loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) | ||
| sft_loss = torch.tensor(-0.07560186833143234) | ||
| self.assertTrue(torch.allclose(loss, sft_loss)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["sft_loss"]), sft_loss)) | ||
|
|
||
| def test_dpo_policy_loss(self): | ||
| policy_loss_fn_cls = POLICY_LOSS_FN.get("dpo") | ||
| policy_loss_fn_args = policy_loss_fn_cls.default_args() | ||
| policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) | ||
| loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) | ||
| dpo_loss = torch.tensor(0.5406752228736877) | ||
| chosen_reward = torch.tensor(0.7082431316375732) | ||
| rejected_reward = torch.tensor(0.3757950782775879) | ||
| accuracy_mean = torch.tensor(1.0) | ||
| self.assertTrue(torch.allclose(loss, dpo_loss)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["chosen_reward"]), chosen_reward)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["rejected_reward"]), rejected_reward)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["accuracy_mean"]), accuracy_mean)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["dpo_loss"]), dpo_loss)) | ||
|
|
||
| def test_opmd_policy_loss(self): | ||
| policy_loss_fn_cls = POLICY_LOSS_FN.get("opmd") | ||
| policy_loss_fn_args = policy_loss_fn_cls.default_args() | ||
| policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) | ||
| loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) | ||
| opmd_loss = torch.tensor(-0.009589947760105133) | ||
| self.assertTrue(torch.allclose(loss, opmd_loss)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["opmd_loss"]), opmd_loss)) | ||
|
|
||
| def test_mix_policy_loss(self): | ||
| policy_loss_fn_cls = POLICY_LOSS_FN.get("mix") | ||
| policy_loss_fn_args = policy_loss_fn_cls.default_args() | ||
| policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) | ||
| loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) | ||
| mix_loss = torch.tensor(0.6581965088844299) | ||
| pg_clipfrac = torch.tensor(0.7777777910232544) | ||
| ppo_kl = torch.tensor(-1.0737695693969727) | ||
| pg_loss = torch.tensor(0.7236452102661133) | ||
| sft_loss = torch.tensor(0.06915830634534359) | ||
| self.assertTrue(torch.allclose(loss, mix_loss)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss)) | ||
| self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """Key Mapper""" | ||
|
|
||
| from typing import Dict | ||
|
|
||
|
|
||
| class KeyMapper: | ||
| def __init__(self, to_trinity_map: Dict[str, str]): | ||
| self.to_trinity_map = to_trinity_map | ||
| self.from_trinity_map = {v: k for k, v in self.to_trinity_map.items()} | ||
|
|
||
| def to_trinity(self, key: str) -> str: | ||
| return self.to_trinity_map.get(key, key) | ||
|
|
||
| def from_trinity(self, key: str) -> str: | ||
| return self.from_trinity_map.get(key, key) | ||
|
|
||
|
|
||
| ALL_MAPPERS = { | ||
| "verl": KeyMapper( | ||
| { | ||
| "log_prob": "logprob", | ||
| "old_log_probs": "old_logprob", | ||
| "ref_log_prob": "ref_logprob", | ||
| "response_mask": "action_mask", | ||
| "advantages": "advantages", | ||
| } | ||
| ), | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.