Commit 70c21ac
Add unittest && bug fix
1 parent fe217aa commit 70c21ac

13 files changed (+305 −89 lines)


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
     "ray[default]>=2.45.0",
-    "vllm>=0.8.5",
+    "vllm==0.8.5.post1",
     "tensordict==0.6.2",
     "wandb",
     "omegaconf",

tests/template/data/sft_for_gsm8k/sft.jsonl

Lines changed: 32 additions & 0 deletions
Large diffs are not rendered by default.
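The warmup data itself is not rendered, but the sft_for_gsm8k entry added to tests/tools.py below reads it as plaintext with prompt_key="prompt" and response_key="response", so each of the 32 added lines is presumably a JSON object with those two fields. The record below is a hypothetical illustration of that shape, not content taken from the actual file:

import json

# Hypothetical sft.jsonl record; only the field names ("prompt", "response") come
# from the FormatConfig below, the question/answer text is illustrative.
record = {
    "prompt": "Natalia sold clips to 48 of her friends in April, and then half as many in May. How many clips did she sell altogether?",
    "response": "She sold 48 / 2 = 24 clips in May, so 48 + 24 = 72 clips altogether. The answer is 72.",
}
print(json.dumps(record))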

tests/tools.py

Lines changed: 47 additions & 0 deletions
@@ -13,6 +13,7 @@
     StorageConfig,
     load_config,
 )
+from trinity.common.constants import PromptType


 def get_template_config() -> Config:
@@ -59,6 +60,47 @@ def get_unittest_dataset_config(
             default_workflow_type="math_workflow",
             default_reward_fn_type="countdown_reward",
         )
+    elif dataset_name == "gsm8k":
+        return StorageConfig(
+            name=dataset_name,
+            path="openai/gsm8k",
+            split=split,
+            subset_name="main",
+            format=FormatConfig(
+                prompt_key="question",
+                response_key="answer",
+            ),
+            rollout_args=GenerationConfig(
+                n=1,
+                temperature=1.0,
+                logprobs=0,
+            ),
+            default_workflow_type="math_workflow",
+            default_reward_fn_type="math_reward",
+        )
+    elif dataset_name == "sft_for_gsm8k":
+        return StorageConfig(
+            name=dataset_name,
+            path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"),
+            split="train",
+            format=FormatConfig(
+                prompt_type=PromptType.PLAINTEXT,
+                prompt_key="prompt",
+                response_key="response",
+            ),
+        )
+    elif dataset_name == "dpo":
+        return StorageConfig(
+            name=dataset_name,
+            path="HumanLLMs/Human-Like-DPO-Dataset",
+            split="train",
+            format=FormatConfig(
+                prompt_type=PromptType.PLAINTEXT,
+                prompt_key="prompt",
+                chosen_key="chosen",
+                rejected_key="rejected",
+            ),
+        )
     else:
         raise ValueError(f"Unknown dataset name: {dataset_name}")

@@ -104,6 +146,11 @@ def metric_steps(self, metric_name: str) -> List[int]:
             raise ValueError(f"Metric '{metric_name}' does not exist.")
         return list(self._metrics[metric_name].keys())

+    def metric_values(self, metric_name: str) -> List:
+        if not self.metric_exist(metric_name):
+            raise ValueError(f"Metric '{metric_name}' does not exist.")
+        return list(self._metrics[metric_name].values())
+
     def metric_list(self, metric_prefix: str) -> List[str]:
         return [name for name in self._metrics if name.startswith(metric_prefix)]
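Taken together, these additions give the trainer tests a dataset factory plus a way to read metric values (not just step indices) back out of TensorBoard; the commented-out reward checks in trainer_test.py below rely on metric_values. A short usage sketch, assuming the helpers stay importable from tests/tools.py and using a hypothetical log directory:

# Illustrative usage of the new helpers; the path and metric name are assumptions.
from tests.tools import TensorBoardParser, get_unittest_dataset_config

taskset = get_unittest_dataset_config("gsm8k")         # StorageConfig pointing at openai/gsm8k
parser = TensorBoardParser("/tmp/checkpoints/tensorboard")

steps = parser.metric_steps("critic/rewards/mean")     # step indices, existing API
rewards = parser.metric_values("critic/rewards/mean")  # per-step values, added in this commit
assert len(steps) == len(rewards)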

tests/trainer/trainer_test.py

Lines changed: 105 additions & 2 deletions
@@ -14,8 +14,8 @@
     get_template_config,
     get_unittest_dataset_config,
 )
-from trinity.cli.launcher import bench, both
-from trinity.common.constants import MonitorType, SyncMethod
+from trinity.cli.launcher import bench, both, train
+from trinity.common.constants import AlgorithmType, MonitorType, SyncMethod


 class BaseTrainerCase(RayUnittestBase):
@@ -109,3 +109,106 @@ def test_trainer(self):
     def tearDown(self):
         # remove dir only when the test passed
         shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerGSM8K(BaseTrainerCase):
+    def test_trainer(self):
+        """Test GSM8K."""
+        # test both mode
+        self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+        self.config.algorithm.repeat_times = 8
+        self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+        self.config.algorithm.advantage_fn_args = {}
+        # self.config.buffer.batch_size = 96  # TODO: used for real testing
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.total_training_steps = 4
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+        both(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+        response_metrics = parser.metric_list("response_length")
+        self.assertTrue(len(response_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+        # TODO: used for real testing
+        # rewards = parser.metric_values("critic/rewards/mean")
+        # self.assertTrue(0.4 < rewards[0] < 0.55)
+        # self.assertTrue(0.4 < rewards[1] < 0.55)
+        # self.assertTrue(0.6 < rewards[2] < 0.7)
+        # self.assertTrue(0.6 < rewards[3] < 0.7)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerGSM8KWithSFT(BaseTrainerCase):
+    def test_trainer(self):
+        """Test GSM8K With SFT."""
+        # test both mode
+        self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+        self.config.algorithm.repeat_times = 8
+        self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+        self.config.algorithm.advantage_fn_args = {}
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.buffer.trainer_input.sft_warmup_steps = 2
+        self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
+            "sft_for_gsm8k"
+        )
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.total_training_steps = 4
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+        both(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)  # SFT
+        self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 4)  # RFT
+        response_metrics = parser.metric_list("response_length")
+        self.assertTrue(len(response_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerDPO(BaseTrainerCase):
+    def test_trainer(self):
+        """Test DPO."""
+        # test train mode
+        self.config.mode = "train"
+        self.config.algorithm.algorithm_type = AlgorithmType.DPO
+        self.config.algorithm.policy_loss_fn = "dpo"
+        self.config.algorithm.policy_loss_fn_args = {}
+        # self.config.buffer.batch_size = 32
+        self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo")
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.total_training_steps = 4
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 5e-7
+        train(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)

trinity/algorithm/policy_loss_fn/dpo_loss.py

Lines changed: 18 additions & 9 deletions
@@ -1,6 +1,6 @@
 """DPO loss function."""

-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple

 import torch
 import torch.nn.functional as F
@@ -22,21 +22,19 @@ def __init__(
     def __call__(
         self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
-        action_mask: torch.Tensor,
-        advantages: torch.Tensor,
-        experiences: Any,
+        ref_log_prob: torch.Tensor,
+        response_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         chosen_logprob = logprob[::2]
         rejected_logprob = logprob[1::2]
-        chosen_mask = action_mask[::2]
-        rejected_mask = action_mask[1::2]
+        chosen_mask = response_mask[::2]
+        rejected_mask = response_mask[1::2]
         chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask)
         rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask)

-        chosen_ref_logprob = old_logprob[::2]
-        rejected_ref_logprob = old_logprob[1::2]
+        chosen_ref_logprob = ref_log_prob[::2]
+        rejected_ref_logprob = ref_log_prob[1::2]
         chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask)
         rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask)

@@ -65,3 +63,14 @@ def default_args(cls) -> Dict:
             "beta": 0.1,
             "label_smoothing": 0.0,
         }
+
+    @property
+    def select_keys(self) -> List[str]:
+        return [
+            "attention_mask",
+            "input_ids",
+            "position_ids",
+            "response_mask",
+            "responses",
+            "ref_log_prob",
+        ]
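The renamed arguments do not change the batch layout this loss assumes: chosen and rejected responses arrive interleaved along the batch dimension (chosen at even indices, rejected at odd), which is what the [::2] / [1::2] slicing encodes. Below is a self-contained sketch of the same pattern with the standard sigmoid DPO objective, using a locally defined masked_sum with the semantics the repo's helper is assumed to have; it is an illustration, not the file's exact code.

import torch
import torch.nn.functional as F


def masked_sum(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Assumed semantics: per-sequence sum of token log-probs over response tokens.
    return (values * mask).sum(dim=-1)


def dpo_loss_sketch(logprob, ref_log_prob, response_mask, beta=0.1):
    # Even rows are chosen responses, odd rows are the paired rejected responses.
    chosen_lp = masked_sum(logprob[::2], response_mask[::2])
    rejected_lp = masked_sum(logprob[1::2], response_mask[1::2])
    chosen_ref = masked_sum(ref_log_prob[::2], response_mask[::2])
    rejected_ref = masked_sum(ref_log_prob[1::2], response_mask[1::2])
    # Standard DPO: -log sigmoid(beta * (policy margin - reference margin)).
    logits = beta * ((chosen_lp - rejected_lp) - (chosen_ref - rejected_ref))
    return -F.logsigmoid(logits).mean()


# Two preference pairs, four tokens each.
lp, ref, mask = torch.randn(4, 4), torch.randn(4, 4), torch.ones(4, 4)
print(dpo_loss_sketch(lp, ref, mask))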

trinity/algorithm/policy_loss_fn/opmd_policy_loss.py

Lines changed: 16 additions & 5 deletions
@@ -3,7 +3,7 @@
 Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """

-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple

 import torch

@@ -19,17 +19,28 @@ def __init__(self, tau: float = 1.0) -> None:
     def __call__(
         self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
-        action_mask: torch.Tensor,
+        old_log_probs: torch.Tensor,
+        response_mask: torch.Tensor,
         advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         pg_losses = -advantages * logprob
-        opmd_loss = masked_mean(pg_losses, action_mask)
+        opmd_loss = masked_mean(pg_losses, response_mask)
         opmd_loss = opmd_loss / (1.0 + self.tau)  # for regularization (w.r.t. current pi_theta)
         return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

     @classmethod
     def default_args(cls) -> Dict:
         return {"tau": 1.0}
+
+    @property
+    def select_keys(self) -> List[str]:
+        return [
+            "responses",
+            "input_ids",
+            "attention_mask",
+            "position_ids",
+            "old_log_probs",
+            "advantages",
+            "response_mask",
+        ]
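Unlike the PPO loss below, this objective multiplies advantages by the log-probability directly rather than by an importance ratio, and then shrinks the result by 1/(1 + tau). A self-contained sketch with a locally defined masked_mean (assumed semantics: average over unmasked response tokens), for illustration only:

import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Assumed semantics: mean of values over positions where mask == 1.
    return (values * mask).sum() / mask.sum().clamp(min=1)


def opmd_loss_sketch(logprob, advantages, response_mask, tau=1.0):
    pg_losses = -advantages * logprob             # no exp(logprob - old_log_probs) ratio here
    loss = masked_mean(pg_losses, response_mask)
    return loss / (1.0 + tau)                     # regularization w.r.t. the current policy


logprob = torch.randn(2, 5)
advantages = torch.randn(2, 5)
mask = torch.ones(2, 5)
print(opmd_loss_sketch(logprob, advantages, mask))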

trinity/algorithm/policy_loss_fn/policy_loss_fn.py

Lines changed: 10 additions & 9 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple

 import torch

@@ -16,20 +16,14 @@ class PolicyLossFn(ABC):
     @abstractmethod
     def __call__(
         self,
-        logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
-        action_mask: torch.Tensor,
-        advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         """
         Args:
             logprob (`torch.Tensor`): The log probability generated by the policy model.
-            old_logprob (`torch.Tensor`): The log probability generated by the reference model.
-            action_mask (`torch.Tensor`): The action mask.
+            old_log_probs (`torch.Tensor`): The log probability generated by the reference model.
+            response_mask (`torch.Tensor`): The response mask.
             advantages (`torch.Tensor`): The advantages.
-            experiences (`DataProto`): The input experiences.
             kwargs (`Dict`): The step-level parameters for calculating the policy loss.

         Returns:
@@ -44,3 +38,10 @@ def default_args(cls) -> Dict:
         Returns:
             `Dict`: The default init arguments for the policy loss function.
         """
+
+    @property
+    def select_keys(self) -> List[str]:
+        """
+        Returns:
+            `List[str]`: The keys to select from input data.
+        """

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 19 additions & 8 deletions
@@ -3,7 +3,7 @@
 Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """

-from typing import Any, Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch

@@ -33,23 +33,22 @@ def __init__(
     def __call__(
         self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
-        action_mask: torch.Tensor,
+        old_log_probs: torch.Tensor,
+        response_mask: torch.Tensor,
         advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
-        negative_approx_kl = logprob - old_logprob
+        negative_approx_kl = logprob - old_log_probs
         ratio = torch.exp(negative_approx_kl)
-        ppo_kl = masked_mean(-negative_approx_kl, action_mask)
+        ppo_kl = masked_mean(-negative_approx_kl, response_mask)

         pg_losses = -advantages * ratio
         pg_losses2 = -advantages * torch.clamp(
             ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
         )

-        pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), action_mask)
-        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
+        pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), response_mask)
+        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), response_mask)
         metrics = {
             "pg_clipfrac": pg_clipfrac.detach().item(),
             "ppo_kl": ppo_kl.detach().item(),
@@ -62,3 +61,15 @@ def default_args(cls) -> Dict:
         return {
             "clip_range": 0.2,
         }
+
+    @property
+    def select_keys(self) -> List[str]:
+        return [
+            "responses",
+            "input_ids",
+            "attention_mask",
+            "position_ids",
+            "old_log_probs",
+            "advantages",
+            "response_mask",
+        ]
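As a quick sanity check on the renamed inputs, here is the clipped objective worked through on made-up numbers, with the clip range fixed at the 0.2 default and an all-ones mask so masked_mean reduces to a plain mean; illustrative only.

import torch

# One response of three tokens, advantage +1 everywhere, policy drifting upward.
logprob = torch.tensor([[-0.9, -1.0, -0.2]])
old_log_probs = torch.tensor([[-1.2, -1.0, -1.0]])
advantages = torch.ones(1, 3)

ratio = torch.exp(logprob - old_log_probs)          # ~[1.35, 1.00, 2.23]
clipped = torch.clamp(ratio, 1.0 - 0.2, 1.0 + 0.2)  # [1.20, 1.00, 1.20]
pg_losses = -advantages * ratio
pg_losses2 = -advantages * clipped
pg_loss = torch.max(pg_losses, pg_losses2).mean()   # ~-1.13: clipping caps the payoff of drifting
pg_clipfrac = torch.gt(pg_losses2, pg_losses).float().mean()  # ~0.67: two of three tokens clipped
print(pg_loss.item(), pg_clipfrac.item())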
