Commit 69ddbd0

Refactor on select_keys (#84)
1 parent b8d1faa commit 69ddbd0

15 files changed: +247, -100 lines changed


docs/sphinx_doc/source/conf.py

Lines changed: 2 additions & 1 deletion
@@ -22,12 +22,13 @@
     "sphinx.ext.napoleon",
     "sphinx.ext.autosectionlabel",
     "myst_parser",
+    "sphinx.ext.mathjax",
 ]
 source_suffix = {
     ".rst": "restructuredtext",
     ".md": "markdown",
 }
-myst_enable_extensions = ["colon_fence"]
+myst_enable_extensions = ["colon_fence", "amsmath", "dollarmath"]

 # Prefix document path to section labels, otherwise autogenerated labels would
 # look like 'heading' rather than 'path/to/file:heading'

docs/sphinx_doc/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ Welcome to Trinity-RFT's documentation!
    tutorial/example_data_functionalities.md
    tutorial/trinity_configs.md
    tutorial/trinity_programming_guide.md
+   tutorial/example_mix_algo.md

 .. toctree::
    :maxdepth: 1

docs/sphinx_doc/source/tutorial/example_mix_algo.md

Lines changed: 5 additions & 9 deletions
@@ -1,12 +1,12 @@
-# Integrate An New Algorithm
+# Integrate A New Algorithm


 This guide introduces how to integrate a new algorithm to Trinity-RFT.
 As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX, which optimizes the following policy objective:

 $$
 \mathcal{J}_{\text{Mix}}(\theta) =
-\mathcal{J}_{\text{GRPO}}(\theta)
+(1-\mu) \mathcal{J}_{\text{GRPO}}(\theta)
 +
 \mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'}
 \left[
@@ -170,6 +170,7 @@ We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`:
 class MIXPolicyLossFn(PolicyLossFn):
     def __init__(
         self,
+        backend: str = "verl",
         mu: float = 0.1,
         clip_range: Optional[float] = None,
         clip_range_low: Optional[float] = None,
@@ -183,6 +184,7 @@ class MIXPolicyLossFn(PolicyLossFn):
         read_batch_size_expert: Optional[int] = None,
         use_token_level_loss_in_sft: bool = True,
     ) -> None:
+        super().__init__(backend=backend)
         self.mu = mu
         self.use_dynamic_bsz = use_dynamic_bsz
         self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer  # type: ignore
@@ -204,11 +206,9 @@
         old_logprob: torch.Tensor,
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
+        is_expert_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
-        is_expert_mask = kwargs.get("is_expert_mask", None)
-        if is_expert_mask is None:
-            raise ValueError("is_expert_mask is required in MIX")
         assert (
             len(is_expert_mask) == logprob.shape[0]
         ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
@@ -271,10 +271,6 @@
             "mu": 0.1,
             "clip_range": 0.2,
         }
-
-    @property
-    def select_keys(self) -> List[str]:
-        return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
 ```

 ## Step 4: Run the Experiment
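
A note on the docs change above: rewriting the first term as $(1-\mu)\,\mathcal{J}_{\text{GRPO}}(\theta)$ makes the objective an explicit convex combination of the GRPO term and the expert term. The expert term's full expression is truncated in the hunk, so it is abbreviated below as $\mathcal{J}_{\text{expert}}$; with the default `"mu": 0.1` from `default_args`, the weighting works out to

$$
\mathcal{J}_{\text{Mix}}(\theta)
= (1-\mu)\,\mathcal{J}_{\text{GRPO}}(\theta) + \mu\,\mathcal{J}_{\text{expert}}(\theta)
= 0.9\,\mathcal{J}_{\text{GRPO}}(\theta) + 0.1\,\mathcal{J}_{\text{expert}}(\theta).
$$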

examples/mix_math/README.md

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@ This example shows the usage of a new algorithm MIX on the MATH dataset.

 For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md).

-The config files are located in [`mix_math.yaml`](mix.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).
+The config files are located in [`mix_math.yaml`](mix_math.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).

examples/mix_math/mix_math.yaml

Lines changed: 1 addition & 2 deletions
@@ -27,7 +27,6 @@ cluster:
 buffer:
   total_epochs: 1
   batch_size: 40
-  explore_batch_size: 36
   max_retry_times: 3
   max_retry_interval: 1
   explorer_input:
@@ -82,7 +81,7 @@ synchronizer:
   sync_timeout: 1200
 trainer:
   trainer_type: 'verl'
-  trainer_config_path: 'examples/mix_math/train_math.yaml'
+  trainer_config_path: 'examples/mix_math/train_mix_math.yaml'
   save_interval: 50
 monitor:
   monitor_type: wandb
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+"""Test for policy loss functions"""
+
+import unittest
+
+import torch
+from verl import DataProto
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
+
+
+class VerlPolicyLossTest(unittest.TestCase):
+    def setUp(self):
+        seed = 42
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+        shape = (5, 20)
+        self.logprob = 2 * torch.rand(shape) - 1
+        self.input_data = DataProto.from_dict(
+            {
+                "old_log_probs": 2 * torch.rand(shape) - 1,
+                "ref_log_prob": 2 * torch.rand(shape) - 1,
+                "response_mask": torch.rand(shape) > 0.5,
+                "advantages": 2 * torch.rand(shape) - 1,
+                "is_expert_mask": torch.rand(shape[0]) > 0.5,
+            }
+        )
+
+    def test_ppo_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        ppo_loss = torch.tensor(0.28560468554496765)
+        pg_clipfrac = torch.tensor(0.3541666567325592)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        self.assertTrue(torch.allclose(loss, ppo_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss))
+
+    def test_sft_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("sft")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        sft_loss = torch.tensor(-0.07560186833143234)
+        self.assertTrue(torch.allclose(loss, sft_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["sft_loss"]), sft_loss))
+
+    def test_dpo_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("dpo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        dpo_loss = torch.tensor(0.5406752228736877)
+        chosen_reward = torch.tensor(0.7082431316375732)
+        rejected_reward = torch.tensor(0.3757950782775879)
+        accuracy_mean = torch.tensor(1.0)
+        self.assertTrue(torch.allclose(loss, dpo_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["chosen_reward"]), chosen_reward))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["rejected_reward"]), rejected_reward))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["accuracy_mean"]), accuracy_mean))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["dpo_loss"]), dpo_loss))
+
+    def test_opmd_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("opmd")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        opmd_loss = torch.tensor(-0.009589947760105133)
+        self.assertTrue(torch.allclose(loss, opmd_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["opmd_loss"]), opmd_loss))
+
+    def test_mix_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("mix")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        mix_loss = torch.tensor(0.6581965088844299)
+        pg_clipfrac = torch.tensor(0.7777777910232544)
+        ppo_kl = torch.tensor(-1.0737695693969727)
+        pg_loss = torch.tensor(0.7236452102661133)
+        sft_loss = torch.tensor(0.06915830634534359)
+        self.assertTrue(torch.allclose(loss, mix_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))

tests/common/config_test.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ def test_all_examples_are_valid(self):
                 config_path = os.path.join(example_dir, example_name, filename)
                 try:
                     config = load_config(config_path)
+                    config.checkpoint_root_dir = "./.cache/"
                     config.check_and_update()
                 except Exception as e:
                     print(f"Error loading config {config_path}: {e}")

trinity/algorithm/key_mapper.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+"""Key Mapper"""
+
+from typing import Dict
+
+
+class KeyMapper:
+    def __init__(self, to_trinity_map: Dict[str, str]):
+        self.to_trinity_map = to_trinity_map
+        self.from_trinity_map = {v: k for k, v in self.to_trinity_map.items()}
+
+    def to_trinity(self, key: str) -> str:
+        return self.to_trinity_map.get(key, key)
+
+    def from_trinity(self, key: str) -> str:
+        return self.from_trinity_map.get(key, key)
+
+
+ALL_MAPPERS = {
+    "verl": KeyMapper(
+        {
+            "log_prob": "logprob",
+            "old_log_probs": "old_logprob",
+            "ref_log_prob": "ref_logprob",
+            "response_mask": "action_mask",
+            "advantages": "advantages",
+        }
+    ),
+}
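
For reviewers, here is how the new mapper behaves in practice; a minimal usage sketch based only on the file added above:

```python
from trinity.algorithm.key_mapper import ALL_MAPPERS

mapper = ALL_MAPPERS["verl"]

# Translate backend-native key names to Trinity's canonical names and back.
assert mapper.to_trinity("old_log_probs") == "old_logprob"
assert mapper.from_trinity("action_mask") == "response_mask"

# Keys without an entry pass through unchanged (dict.get falls back to the key itself).
assert mapper.to_trinity("advantages") == "advantages"
assert mapper.to_trinity("is_expert_mask") == "is_expert_mask"
```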

trinity/algorithm/policy_loss_fn/dpo_loss.py

Lines changed: 3 additions & 8 deletions
@@ -1,6 +1,6 @@
 """DPO loss function."""

-from typing import Dict, List, Tuple
+from typing import Dict, Tuple

 import torch
 import torch.nn.functional as F
@@ -13,9 +13,11 @@
 class DPOLossFn(PolicyLossFn):
     def __init__(
         self,
+        backend: str = "verl",
         beta: float = 0.1,
         label_smoothing: float = 0.0,
     ) -> None:
+        super().__init__(backend=backend)
         self.beta = beta
         self.label_smoothing = label_smoothing

@@ -63,10 +65,3 @@ def default_args(cls) -> Dict:
             "beta": 0.1,
             "label_smoothing": 0.0,
         }
-
-    @property
-    def select_keys(self) -> List[str]:
-        return [
-            "ref_logprob",
-            "action_mask",
-        ]

trinity/algorithm/policy_loss_fn/mix_policy_loss.py

Lines changed: 14 additions & 18 deletions
@@ -1,6 +1,6 @@
 """Mix policy loss function."""

-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple

 import torch

@@ -26,27 +26,29 @@ class MIXPolicyLossFn(PolicyLossFn):

     def __init__(
         self,
+        backend: str = "verl",
         mu: float = 0.1,
         clip_range: Optional[float] = None,
         clip_range_low: Optional[float] = None,
         clip_range_high: Optional[float] = None,
         use_dynamic_bsz: Optional[bool] = None,
-        repeat_times: Optional[int] = None,
-        ppo_mini_batch_size: Optional[int] = None,
-        ppo_micro_batch_size_per_gpu: Optional[int] = None,
-        ngpus_trainer: Optional[int] = None,
-        read_batch_size_usual: Optional[int] = None,
-        read_batch_size_expert: Optional[int] = None,
+        repeat_times: int = 1,
+        ppo_mini_batch_size: int = 1,
+        ppo_micro_batch_size_per_gpu: int = 1,
+        ngpus_trainer: int = 1,
+        read_batch_size_usual: int = 1,
+        read_batch_size_expert: int = 1,
         use_token_level_loss_in_sft: bool = True,
     ) -> None:
+        super().__init__(backend=backend)
         self.mu = mu
         self.use_dynamic_bsz = use_dynamic_bsz
-        self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer  # type: ignore
+        self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer
        self.gradient_accumulation = (
-            ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu  # type: ignore
+            ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu
         )
-        self.read_batch_size_usual = read_batch_size_usual
-        self.read_batch_size_expert = read_batch_size_expert
+        self.read_batch_size_usual = read_batch_size_usual // ngpus_trainer
+        self.read_batch_size_expert = read_batch_size_expert // ngpus_trainer
         self.grpo_loss_fn = PPOPolicyLossFn(
             clip_range=clip_range,
             clip_range_low=clip_range_low,
@@ -60,11 +62,9 @@ def __call__(  # type: ignore
         old_logprob: torch.Tensor,
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
+        is_expert_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
-        is_expert_mask = kwargs.get("is_expert_mask", None)
-        if is_expert_mask is None:
-            raise ValueError("is_expert_mask is required in MIX")
         assert (
             len(is_expert_mask) == logprob.shape[0]
         ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
@@ -127,7 +127,3 @@ def default_args(cls) -> Dict:
             "mu": 0.1,
             "clip_range": 0.2,
         }
-
-    @property
-    def select_keys(self) -> List[str]:
-        return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]

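Taken together with the new `key_mapper.py`, the pattern in these changes is that loss functions no longer list their required batch keys by hand: the per-class `select_keys` properties are deleted, each `__init__` now forwards a `backend` argument to the base class, and each required tensor (e.g. `is_expert_mask`) is promoted to an explicit `__call__` parameter. The `PolicyLossFn` base class is not part of this diff, so the following is only a hypothetical sketch of how the required keys could now be derived from the call signature and translated with the backend's `KeyMapper`; the `infer_select_keys` helper and the `backend` attribute lookup are assumptions, not code from the repository:

```python
import inspect
from typing import List

from trinity.algorithm.key_mapper import ALL_MAPPERS


def infer_select_keys(loss_fn) -> List[str]:
    """Hypothetical: derive backend-native key names from a loss fn's __call__ signature."""
    mapper = ALL_MAPPERS[loss_fn.backend]  # assumes the base class stores the backend name
    params = inspect.signature(loss_fn.__call__).parameters
    return [
        mapper.from_trinity(name)  # e.g. "old_logprob" -> "old_log_probs" for verl
        for name, param in params.items()
        if param.kind is not inspect.Parameter.VAR_KEYWORD and name != "self"
    ]
```

Under this reading, the `__call__` parameters of `MIXPolicyLossFn` cover everything its deleted `select_keys` property used to list, which is consistent with the commit title "Refactor on select_keys".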