
Commit 9d582e8

Add unittest && bug fix (#65)
1 parent fe217aa commit 9d582e8

14 files changed: +295, -72 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
     "ray[default]>=2.45.0",
-    "vllm>=0.8.5",
+    "vllm==0.8.5.post1",
     "tensordict==0.6.2",
     "wandb",
     "omegaconf",

tests/template/config.yaml

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@ model:
   max_prompt_tokens: 2048
   max_response_tokens: 2048
 cluster: # 2 for explorer, 2 for trainer
-  node_num: 1
-  gpu_per_node: 4
+  node_num: 2
+  gpu_per_node: 2
 buffer:
   total_epochs: 1
   batch_size: 4
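
Per the inline comment, the new 2 x 2 layout keeps the overall GPU budget at four while giving the explorer and the trainer two GPUs each. A trivial sanity check of that arithmetic (illustrative only, not part of the repo):

node_num = 2
gpu_per_node = 2
explorer_gpus, trainer_gpus = 2, 2  # "2 for explorer, 2 for trainer"
assert node_num * gpu_per_node == explorer_gpus + trainer_gpus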

tests/template/data/sft_for_gsm8k/sft.jsonl

Lines changed: 32 additions & 0 deletions
Large diffs are not rendered by default.
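
The 32 new lines are not rendered above. Judging from the FormatConfig added for "sft_for_gsm8k" in tests/tools.py (plaintext prompts with prompt_key="prompt" and response_key="response"), each record presumably has the following shape; the values here are placeholders, not actual file contents:

import json

record = {"prompt": "<GSM8K-style question>", "response": "<worked solution with final answer>"}
print(json.dumps(record))  # one JSON object per line in sft.jsonl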

tests/tools.py

Lines changed: 47 additions & 0 deletions
@@ -13,6 +13,7 @@
     StorageConfig,
     load_config,
 )
+from trinity.common.constants import PromptType


 def get_template_config() -> Config:
@@ -59,6 +60,47 @@ def get_unittest_dataset_config(
             default_workflow_type="math_workflow",
             default_reward_fn_type="countdown_reward",
         )
+    elif dataset_name == "gsm8k":
+        return StorageConfig(
+            name=dataset_name,
+            path="openai/gsm8k",
+            split=split,
+            subset_name="main",
+            format=FormatConfig(
+                prompt_key="question",
+                response_key="answer",
+            ),
+            rollout_args=GenerationConfig(
+                n=1,
+                temperature=1.0,
+                logprobs=0,
+            ),
+            default_workflow_type="math_workflow",
+            default_reward_fn_type="math_reward",
+        )
+    elif dataset_name == "sft_for_gsm8k":
+        return StorageConfig(
+            name=dataset_name,
+            path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"),
+            split="train",
+            format=FormatConfig(
+                prompt_type=PromptType.PLAINTEXT,
+                prompt_key="prompt",
+                response_key="response",
+            ),
+        )
+    elif dataset_name == "dpo":
+        return StorageConfig(
+            name=dataset_name,
+            path="HumanLLMs/Human-Like-DPO-Dataset",
+            split="train",
+            format=FormatConfig(
+                prompt_type=PromptType.PLAINTEXT,
+                prompt_key="prompt",
+                chosen_key="chosen",
+                rejected_key="rejected",
+            ),
+        )
     else:
         raise ValueError(f"Unknown dataset name: {dataset_name}")

@@ -104,6 +146,11 @@ def metric_steps(self, metric_name: str) -> List[int]:
             raise ValueError(f"Metric '{metric_name}' does not exist.")
         return list(self._metrics[metric_name].keys())

+    def metric_values(self, metric_name: str) -> List:
+        if not self.metric_exist(metric_name):
+            raise ValueError(f"Metric '{metric_name}' does not exist.")
+        return list(self._metrics[metric_name].values())
+
     def metric_list(self, metric_prefix: str) -> List[str]:
         return [name for name in self._metrics if name.startswith(metric_prefix)]
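
A minimal usage sketch of the new helpers, assuming TensorBoardParser lives in tests/tools.py alongside get_unittest_dataset_config (the log directory below is hypothetical):

from tests.tools import TensorBoardParser, get_unittest_dataset_config

taskset = get_unittest_dataset_config("gsm8k")            # HuggingFace openai/gsm8k taskset
sft_data = get_unittest_dataset_config("sft_for_gsm8k")   # local plaintext SFT warmup data

parser = TensorBoardParser("/tmp/checkpoints/tensorboard")  # hypothetical path
steps = parser.metric_steps("critic/rewards/mean")          # steps at which the metric was logged
values = parser.metric_values("critic/rewards/mean")        # values in the same step order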

tests/trainer/trainer_test.py

Lines changed: 106 additions & 2 deletions
@@ -14,8 +14,8 @@
     get_template_config,
     get_unittest_dataset_config,
 )
-from trinity.cli.launcher import bench, both
-from trinity.common.constants import MonitorType, SyncMethod
+from trinity.cli.launcher import bench, both, train
+from trinity.common.constants import AlgorithmType, MonitorType, SyncMethod


 class BaseTrainerCase(RayUnittestBase):
@@ -109,3 +109,107 @@ def test_trainer(self):
     def tearDown(self):
         # remove dir only when the test passed
         shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerGSM8K(BaseTrainerCase):
+    def test_trainer(self):
+        """Test GSM8K."""
+        # test both mode
+        self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+        self.config.algorithm.repeat_times = 4
+        # self.config.algorithm.repeat_times = 8  # TODO: used for real testing
+        self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+        self.config.algorithm.advantage_fn_args = {}
+        # self.config.buffer.batch_size = 96  # TODO: used for real testing
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.total_training_steps = 4
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+        both(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+        response_metrics = parser.metric_list("response_length")
+        self.assertTrue(len(response_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+        # TODO: used for real testing
+        # rewards = parser.metric_values("critic/rewards/mean")
+        # self.assertTrue(0.4 < rewards[0] < 0.55)
+        # self.assertTrue(0.4 < rewards[1] < 0.55)
+        # self.assertTrue(0.6 < rewards[2] < 0.7)
+        # self.assertTrue(0.6 < rewards[3] < 0.7)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerGSM8KWithSFT(BaseTrainerCase):
+    def test_trainer(self):
+        """Test GSM8K With SFT."""
+        # test both mode
+        self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+        self.config.algorithm.repeat_times = 4
+        self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+        self.config.algorithm.advantage_fn_args = {}
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.buffer.trainer_input.sft_warmup_steps = 2
+        self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
+            "sft_for_gsm8k"
+        )
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.total_training_steps = 4
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+        both(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)  # SFT
+        self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 4)  # RFT
+        response_metrics = parser.metric_list("response_length")
+        self.assertTrue(len(response_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerDPO(BaseTrainerCase):
+    def test_trainer(self):
+        """Test DPO."""
+        # test both mode
+        self.config.mode = "train"
+        self.config.algorithm.algorithm_type = AlgorithmType.DPO
+        self.config.algorithm.policy_loss_fn = "dpo"
+        self.config.algorithm.policy_loss_fn_args = {}
+        # self.config.buffer.batch_size = 32
+        self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo")
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.total_training_steps = 4
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 5e-7
+        train(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
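
The new cases can be exercised individually with the standard unittest loader; a sketch with the module path inferred from the file location (these tests assume the Ray setup and GPU topology configured above):

import unittest

from tests.trainer.trainer_test import TestTrainerDPO, TestTrainerGSM8K

suite = unittest.TestSuite()
suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestTrainerGSM8K))
suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestTrainerDPO))
unittest.TextTestRunner(verbosity=2).run(suite)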

trinity/algorithm/policy_loss_fn/dpo_loss.py

Lines changed: 12 additions & 7 deletions
@@ -1,6 +1,6 @@
 """DPO loss function."""

-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple

 import torch
 import torch.nn.functional as F
@@ -19,13 +19,11 @@ def __init__(
         self.beta = beta
         self.label_smoothing = label_smoothing

-    def __call__(
+    def __call__(  # type: ignore
         self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
+        ref_logprob: torch.Tensor,
         action_mask: torch.Tensor,
-        advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         chosen_logprob = logprob[::2]
@@ -35,8 +33,8 @@ def __call__(
         chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask)
         rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask)

-        chosen_ref_logprob = old_logprob[::2]
-        rejected_ref_logprob = old_logprob[1::2]
+        chosen_ref_logprob = ref_logprob[::2]
+        rejected_ref_logprob = ref_logprob[1::2]
         chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask)
         rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask)

@@ -65,3 +63,10 @@ def default_args(cls) -> Dict:
             "beta": 0.1,
             "label_smoothing": 0.0,
         }
+
+    @property
+    def select_keys(self) -> List[str]:
+        return [
+            "ref_logprob",
+            "action_mask",
+        ]
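
The slicing above assumes chosen responses sit at even batch indices and their rejected counterparts at the following odd indices. A standalone sketch of the objective those per-sequence sums feed (standard DPO loss with label smoothing; not the repo's exact code, which builds the sums with masked_sum over action_mask):

import torch
import torch.nn.functional as F

def dpo_loss_sketch(logprob_sum, ref_logprob_sum, beta=0.1, label_smoothing=0.0):
    """logprob_sum / ref_logprob_sum: shape (2 * num_pairs,), per-sequence log-prob sums."""
    chosen, rejected = logprob_sum[::2], logprob_sum[1::2]
    ref_chosen, ref_rejected = ref_logprob_sum[::2], ref_logprob_sum[1::2]
    logits = beta * ((chosen - ref_chosen) - (rejected - ref_rejected))
    losses = -F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing
    return losses.mean()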

trinity/algorithm/policy_loss_fn/opmd_policy_loss.py

Lines changed: 11 additions & 4 deletions
@@ -3,7 +3,7 @@
 Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """

-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple

 import torch

@@ -16,13 +16,12 @@ class OPMDPolicyLossFn(PolicyLossFn):
     def __init__(self, tau: float = 1.0) -> None:
         self.tau = tau

-    def __call__(
+    def __call__(  # type: ignore
         self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
+        old_logprob: torch.Tensor,  # NOT USED!
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         pg_losses = -advantages * logprob
@@ -33,3 +32,11 @@ def __call__(
     @classmethod
     def default_args(cls) -> Dict:
         return {"tau": 1.0}
+
+    @property
+    def select_keys(self) -> List[str]:
+        return [
+            "old_logprob",
+            "action_mask",
+            "advantages",
+        ]

trinity/algorithm/policy_loss_fn/policy_loss_fn.py

Lines changed: 9 additions & 6 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple

 import torch

@@ -17,10 +17,6 @@ class PolicyLossFn(ABC):
     def __call__(
         self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
-        action_mask: torch.Tensor,
-        advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         """
@@ -29,7 +25,6 @@ def __call__(
             old_logprob (`torch.Tensor`): The log probability generated by the reference model.
             action_mask (`torch.Tensor`): The action mask.
             advantages (`torch.Tensor`): The advantages.
-            experiences (`DataProto`): The input experiences.
             kwargs (`Dict`): The step-level parameters for calculating the policy loss.

         Returns:
@@ -44,3 +39,11 @@ def default_args(cls) -> Dict:
         Returns:
             `Dict`: The default init arguments for the policy loss function.
         """
+
+    @property
+    @abstractmethod
+    def select_keys(self) -> List[str]:
+        """
+        Returns:
+            `List[str]`: The keys to select from input data.
+        """

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 10 additions & 3 deletions
@@ -3,7 +3,7 @@
 Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """

-from typing import Any, Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch

@@ -30,13 +30,12 @@ def __init__(
         assert self.clip_range_low is not None, "clip_range_low must be specified."
         assert self.clip_range_high is not None, "clip_range_high must be specified."

-    def __call__(
+    def __call__(  # type: ignore
         self,
         logprob: torch.Tensor,
         old_logprob: torch.Tensor,
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         negative_approx_kl = logprob - old_logprob
@@ -62,3 +61,11 @@ def default_args(cls) -> Dict:
         return {
             "clip_range": 0.2,
         }
+
+    @property
+    def select_keys(self) -> List[str]:
+        return [
+            "old_logprob",
+            "action_mask",
+            "advantages",
+        ]
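
For reference, the clipped surrogate these inputs feed looks roughly like the following (standalone sketch, not the repo's exact implementation, which also aggregates with action_mask and reports clip-fraction metrics):

import torch

def ppo_clip_sketch(logprob, old_logprob, advantages, clip_low=0.2, clip_high=0.2):
    ratio = torch.exp(logprob - old_logprob)  # exp(negative_approx_kl)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_low, 1.0 + clip_high)
    return torch.max(unclipped, clipped).mean()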

trinity/algorithm/policy_loss_fn/sft_loss.py

Lines changed: 6 additions & 5 deletions
@@ -1,6 +1,6 @@
 """SFT loss function."""

-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple

 import torch

@@ -13,13 +13,10 @@ class SFTLossFn(PolicyLossFn):
     def __init__(self, use_token_level_loss: bool = True) -> None:
         self.use_token_level_loss = use_token_level_loss

-    def __call__(
+    def __call__(  # type: ignore
         self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
         action_mask: torch.Tensor,
-        advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         if self.use_token_level_loss:
@@ -33,3 +30,7 @@ def default_args(cls):
         return {
             "use_token_level_loss": True,
         }
+
+    @property
+    def select_keys(self) -> List[str]:
+        return ["action_mask"]
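
A standalone sketch of the switch use_token_level_loss controls: both variants are a masked negative log-likelihood over response tokens, differing only in whether the average is taken over all tokens in the batch or per sequence first (illustrative, not the repo code):

import torch

def sft_loss_sketch(logprob, action_mask, use_token_level_loss=True):
    if use_token_level_loss:
        # average over all response tokens in the batch
        return -(logprob * action_mask).sum() / action_mask.sum()
    # average per-sequence means instead
    per_sequence = (logprob * action_mask).sum(dim=-1) / action_mask.sum(dim=-1)
    return -per_sequence.mean()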
