Commit 3651ad8
Add more unittest and support Qwen3 (#29)
1 parent 8795012 commit 3651ad8

29 files changed: +578 −53 lines changed

.github/workflows/docker/docker-compose.yaml

Lines changed: 4 additions & 2 deletions

@@ -8,7 +8,8 @@ services:
       - RAY_ADDRESS=auto
       - CHECKPOINT_ROOT_DIR=/mnt/checkpoints
       - DATA_ROOT_DIR=/mnt/data
-      - MODEL_PATH=/mnt/checkpoints/Qwen2.5-1.5B-Instruct
+      - MODEL_PATH=/mnt/models/Qwen3-1.7B
+      - CHECKPOINT_PATH=/mnt/checkpoints
     working_dir: /workspace
     networks:
       - trinity-network
@@ -32,7 +33,8 @@ services:
       - HF_ENDPOINT=https://hf-mirror.com
       - CHECKPOINT_ROOT_DIR=/mnt/checkpoints
       - DATA_ROOT_DIR=/mnt/data
-      - MODEL_PATH=/mnt/checkpoints/Qwen2.5-1.5B-Instruct
+      - MODEL_PATH=/mnt/models/Qwen3-1.7B
+      - CHECKPOINT_PATH=/mnt/checkpoints
     working_dir: /workspace
     volumes:
       - trinity-volume:/mnt
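
The compose change splits one variable into two roles: MODEL_PATH now points at the base model (a Qwen3 checkpoint under /mnt/models) and the new CHECKPOINT_PATH points at where training output goes. A sketch of how tests can pick these up; get_model_path mirrors the helper visible in tests/common/vllm_test.py below, while get_checkpoint_path is an assumed counterpart (the real one lives in the tests/tools module, which this view does not show):

    import os

    def get_model_path() -> str:
        # Mirrors the helper in tests/common/vllm_test.py.
        path = os.environ.get("MODEL_PATH")
        if path is None:
            raise RuntimeError("MODEL_PATH is not set")  # assumed failure mode
        return path

    def get_checkpoint_path() -> str:
        # Assumed counterpart for the new CHECKPOINT_PATH variable.
        path = os.environ.get("CHECKPOINT_PATH")
        if path is None:
            raise RuntimeError("CHECKPOINT_PATH is not set")  # assumed failure mode
        return path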

.github/workflows/unittest.yaml

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ jobs:
       - name: Run unittest
         working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
         run: |
-          docker compose exec trinity-node-1 pytest tests --ignore=tests/data --ctrf report.json
+          docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json

       - name: Upload test results
         uses: actions/upload-artifact@v4
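
The added -v -s flags make pytest print each test name and stop capturing stdout, so the print() calls in the new tests (queue_test.py, runner_pool_test.py) show up in the CI log. For reference, a roughly equivalent local run through pytest's Python API (the --ctrf flag comes from a CTRF report plugin and is left out here):

    import pytest

    # -v: verbose test names; -s: disable output capture so print() is visible.
    pytest.main(["tests", "-v", "-s", "--ignore=tests/data"])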

pyproject.toml

Lines changed: 2 additions & 2 deletions

@@ -20,8 +20,8 @@ classifiers = [
 requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
-    "ray==2.43.0",
-    "vllm==0.8.3",
+    "ray[default]==2.43.0",
+    "vllm>=0.8.3",
     "tensordict==0.6.2",
     "wandb",
     "omegaconf",

tests/buffer/queue_test.py

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+import unittest
+
+import ray
+import torch
+
+from trinity.buffer.reader.queue_reader import QueueReader
+from trinity.buffer.writer.queue_writer import QueueWriter
+from trinity.common.config import BufferConfig, DatasetConfig
+from trinity.common.constants import AlgorithmType, StorageType
+from trinity.common.experience import Experience
+
+
+class TestQueueBuffer(unittest.TestCase):
+    def setUp(self):
+        ray.init(ignore_reinit_error=True)
+
+    def test_queue_buffer(self):
+        total_num = 8
+        put_batch_size = 2
+        read_batch_size = 4
+        meta = DatasetConfig(
+            name="test_buffer",
+            algorithm_type=AlgorithmType.PPO,
+            storage_type=StorageType.QUEUE,
+        )
+        config = BufferConfig(
+            max_retry_times=3,
+            max_retry_interval=1,
+            read_batch_size=read_batch_size,
+        )
+        writer = QueueWriter(meta, config)
+        reader = QueueReader(meta, config)
+        exps = [
+            Experience(
+                tokens=torch.tensor([float(j) for j in range(i + 1)]),
+                prompt_length=i,
+                reward=float(i),
+                logprobs=torch.tensor([0.1]),
+            )
+            for i in range(1, put_batch_size + 1)
+        ]
+        for _ in range(total_num // put_batch_size):
+            writer.write(exps)
+        writer.finish()
+        for _ in range(total_num // read_batch_size):
+            exps = reader.read()
+            self.assertEqual(len(exps), read_batch_size)
+            print(f"finish read {read_batch_size} experience")
+        self.assertRaises(StopIteration, reader.read)
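
The interesting contract here is at the edges: writer.finish() marks the stream complete, and once the queue is drained a further read() raises StopIteration. A minimal illustration of that contract with a plain in-memory queue (this is a sketch of the behavior under test, not Trinity's actual queue implementation):

    import queue

    _SENTINEL = object()  # marks end-of-stream, mirroring writer.finish()

    class MiniQueueWriter:
        def __init__(self, q: queue.Queue):
            self.q = q

        def write(self, items):
            for item in items:
                self.q.put(item)

        def finish(self):
            self.q.put(_SENTINEL)

    class MiniQueueReader:
        def __init__(self, q: queue.Queue, batch_size: int):
            self.q = q
            self.batch_size = batch_size

        def read(self):
            batch = []
            while len(batch) < self.batch_size:
                item = self.q.get()
                if item is _SENTINEL:
                    raise StopIteration  # stream finished, nothing left to read
                batch.append(item)
            return batch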

tests/buffer/sql_test.py

Lines changed: 6 additions & 4 deletions

@@ -14,8 +14,9 @@

 class TestSQLBuffer(unittest.TestCase):
     def test_create_sql_buffer(self) -> None:
-        put_batch_size = 4
-        read_batch_size = 2
+        total_num = 8
+        put_batch_size = 2
+        read_batch_size = 4
         meta = DatasetConfig(
             name="test_buffer",
             algorithm_type=AlgorithmType.PPO,
@@ -39,7 +40,8 @@ def test_create_sql_buffer(self) -> None:
             )
             for i in range(1, put_batch_size + 1)
         ]
-        sql_writer.write(exps)
-        for _ in range(put_batch_size // read_batch_size):
+        for _ in range(total_num // put_batch_size):
+            sql_writer.write(exps)
+        for _ in range(total_num // read_batch_size):
             exps = sql_reader.read()
             self.assertEqual(len(exps), read_batch_size)
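
With these values the test now pushes 8 experiences through the buffer in 4 writes of 2 (total_num // put_batch_size = 4) and drains them in 2 reads of 4 (total_num // read_batch_size = 2), mirroring the new queue test so both storage backends are exercised with the same write/read pattern.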

tests/common/config_test.py

Lines changed: 3 additions & 5 deletions

@@ -3,18 +3,16 @@
 import os
 import unittest

+from tests.tools import get_template_config
 from trinity.common.config import load_config

-config_yaml_path = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")
-

 class TestConfig(unittest.TestCase):
     def test_load_default_config(self):
-        config = load_config(config_yaml_path)
-        print(config.data)
+        config = get_template_config()
         config.check_and_update()
         self.assertIsNotNone(config.trainer.trainer_config)
-        self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 4)
+        self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 2)
         self.assertEqual(config.trainer.trainer_config.trainer.nnodes, 1)
         self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.monitor.project)
         self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.monitor.name)
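
get_template_config replaces the hard-coded tests/common/tmp path with a shared helper from the new tests/tools module, which this commit view does not show. A plausible sketch of it, assuming it simply loads a checked-in template (the tests/template location is a guess, hinted at by the trainer_config_path change further down):

    import os

    from trinity.common.config import load_config

    def get_template_config():
        # Assumed layout: tests/tools.py next to a tests/template/ directory.
        template_path = os.path.join(os.path.dirname(__file__), "template", "config.yaml")
        return load_config(template_path)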

tests/common/vllm_test.py

Lines changed: 27 additions & 8 deletions

@@ -5,16 +5,14 @@
 import torch
 from transformers import AutoTokenizer

-from trinity.common.config import load_config
+from tests.tools import RayUnittestBase, get_template_config
 from trinity.common.models import create_rollout_models
 from trinity.common.models.model import ModelWrapper
 from trinity.common.models.utils import (
     tokenize_and_mask_messages_default,
     tokenize_and_mask_messages_hf,
 )

-config_dir = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")
-

 def get_model_path() -> str:
     path = os.environ.get("MODEL_PATH")
@@ -101,7 +99,12 @@ def test_generate(self):
         ]
         results = self.model_wrapper.chat(messages)
         self.assertEqual(len(results), self.config.explorer.repeat_times)
-        logprobs = self.model_wrapper.logprobs(results[0].tokens)
+        for result in results:
+            input_logprobs = result.logprobs[: result.prompt_length]
+            output_logprobs = result.logprobs[result.prompt_length :]
+            self.assertTrue(torch.all(input_logprobs == 0))
+            self.assertTrue(torch.any(output_logprobs != 0))
+        logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist())
         self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0])
         messages.append(
             {
@@ -126,10 +129,10 @@ def test_generate(self):
         self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))


-class TestModelWrapperSync(BaseTestModelWrapper, unittest.TestCase):
+class TestModelWrapperSync(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
         ray.init(ignore_reinit_error=True)
-        self.config = load_config(config_dir)
+        self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm"
         self.config.explorer.engine_num = 1
@@ -138,10 +141,18 @@ def setUp(self):
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")


-class TestModelWrapperAsync(BaseTestModelWrapper, unittest.TestCase):
+class TestModelWrapperAsync(BaseTestModelWrapper, RayUnittestBase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(ignore_reinit_error=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
     def setUp(self):
         ray.init(ignore_reinit_error=True)
-        self.config = load_config(config_dir)
+        self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 1
@@ -151,6 +162,14 @@ def setUp(self):


 class TestTokenizer(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(ignore_reinit_error=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
     def test_assistant_token_mask(self):
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},

tests/explorer/explorer_test.py

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+"""Tests for explorer."""
+import os
+from abc import abstractmethod
+from datetime import datetime
+
+from tests.tools import (
+    RayUnittestBase,
+    TensorBoardParser,
+    get_checkpoint_path,
+    get_model_path,
+    get_template_config,
+    get_unittest_dataset_config,
+)
+from trinity.cli.launcher import explore
+from trinity.common.constants import MonitorType
+
+
+class BaseExplorerCase(RayUnittestBase):
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.engine_type = "vllm_async"
+        self.config.explorer.repeat_times = 2
+        self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+        self.config.monitor.project = "Trinity-unittest"
+        self.config.model.checkpoint_path = get_checkpoint_path()
+        self.config.synchronizer.sync_iteration_interval = 2
+        self.config.explorer.eval_interval = 4
+        self.config.trainer.eval_interval = 4
+
+    @abstractmethod
+    def test_explorer(self):
+        """Test explorer"""
+
+
+class TestExplorerCountdownEval(BaseExplorerCase):
+    def test_explorer(self):
+        self.config.data = get_unittest_dataset_config("countdown")
+        self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+        self.config.check_and_update()
+        explore(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        eval_metrics = parser.metric_list("eval")
+        self.assertTrue(len(eval_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
+        self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
+
+
+class TestExplorerCountdownNoEval(BaseExplorerCase):
+    def test_explorer(self):
+        self.config.data = get_unittest_dataset_config("countdown")
+        self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+        self.config.data.eval_split = None
+        self.config.check_and_update()
+        explore(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        eval_metrics = parser.metric_list("eval")
+        self.assertTrue(len(eval_metrics) == 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
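
TensorBoardParser also comes from the unshown tests/tools module. One plausible implementation, inferred purely from how it is used above, wraps TensorBoard's event accumulator:

    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    class TensorBoardParser:
        def __init__(self, log_dir: str):
            self._acc = EventAccumulator(log_dir)
            self._acc.Reload()  # scan event files on disk

        def metric_list(self, prefix: str) -> list:
            # All scalar tags whose names start with the given prefix.
            return [tag for tag in self._acc.Tags()["scalars"] if tag.startswith(prefix)]

        def metric_max_step(self, tag: str) -> int:
            # Largest global step recorded for this tag.
            return max(event.step for event in self._acc.Scalars(tag))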

tests/explorer/runner_pool_test.py

Lines changed: 5 additions & 4 deletions

@@ -31,7 +31,7 @@ def run(self) -> List[Experience]:
         if "timeout" in self.error_type:
             time.sleep(self.seconds)
         elif self.error_type == "exception":
-            raise RuntimeError("Exception occurred")
+            raise ValueError("Exception occurred")
         elif self.error_type == "exit":
             exit(1)
         return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)]
@@ -107,19 +107,20 @@ def test_runner_pool(self):
             tasks=tasks,
         )

-        # The expected return order is: `exception` -> `timeout_5` -> `success` -> (`timeout_100` and `timeout_101`) -> `exit`
+        # The expected return order is: `exception` -> `timeout_2` -> `success` -> (`timeout_100` and `timeout_101`) -> `exit`
         # 1. `exception`
         st = time.time()
         status = pool.get_next_unorder()
         et = time.time()
-        self.assertTrue(et - st < 5)
+        self.assertTrue(et - st < 2)
+        print(f"First task use time: {et - st}")
         self.assertEqual(len(status), 1)
         self.assertFalse(status[0].ok)
         # 2. `timeout_2`
         st = time.time()
         status = pool.get_next_unorder()
         et = time.time()
-        self.assertTrue(et - st < 3)
+        self.assertTrue(et - st > 2)
         self.assertEqual(len(status), 1)
         self.assertTrue(status[0].ok)
         # 3. `success`
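
The reworked timing bounds follow directly from the task definitions: `exception` raises immediately, so the first result must come back in well under 2 seconds, while `timeout_2` sleeps for 2 seconds before returning successfully, so its completion time is now asserted as a lower bound (> 2) rather than the previous upper bound (< 3).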

Lines changed: 9 additions & 12 deletions

@@ -2,7 +2,7 @@ mode: both
 data:
   dataset_path: ''
   total_epochs: 1
-  batch_size: 32
+  batch_size: 4
   train_split: 'train'
   eval_split: ''
   default_workflow_type: ''
@@ -12,34 +12,31 @@ model:
   max_prompt_tokens: 2048
   max_response_tokens: 2048
   checkpoint_path: ''
-cluster:
+cluster: # 2 for explorer, 2 for trainer
   node_num: 1
-  gpu_per_node: 8
+  gpu_per_node: 4
 buffer:
-  read_batch_size: 32
   max_retry_times: 3
   max_retry_interval: 1
 explorer:
   engine_type: vllm_async
   engine_num: 2
-  runner_num: 16
-  repeat_times: 2
-  tensor_parallel_size: 2
+  runner_num: 4
+  repeat_times: 1
+  tensor_parallel_size: 1
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 0.2
-  top_p: 0.95
+  temperature: 1.0
+  top_p: 1.0
   top_k: -1
   seed: 42
   logprobs: 0
   backend: nccl
   use_ray: false
-  max_pending_requests: 5
-  max_waiting_steps: 1
 trainer:
   trainer_type: verl
-  trainer_config_path: tests/common/tmp/template_verl_config.yaml
+  trainer_config_path: tests/template/verl_config.yaml
 monitor:
   project: unittest
   name: test
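
Note how the template and the config test move together: the template now budgets 4 GPUs per node split evenly between explorer and trainer (per the inline comment), which is exactly what the updated assertion n_gpus_per_node == 2 in tests/common/config_test.py verifies.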
