Skip to content

Commit 5829917

Browse files
committed
Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/config_refactor
2 parents 9cb0827 + 6c56401 commit 5829917

File tree

15 files changed: 138 additions and 123 deletions

15 files changed: 138 additions and 123 deletions

.github/workflows/unittest.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
- uses: actions/checkout@v4
2020
with:
2121
path: trinity-${{ github.run_id }}
22-
fetch-depth: 0
22+
ref: refs/pull/${{ github.event.issue.number }}/head
2323

2424
- name: Setup docker compose
2525
working-directory: trinity-${{ github.run_id }}/.github/workflows/docker

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ requires-python = ">=3.10"
2121
dependencies = [
2222
"verl==0.3.0.post1",
2323
"ray[default]==2.43.0",
24-
"vllm>=0.8.3",
24+
"vllm>=0.8.5",
2525
"tensordict==0.6.2",
2626
"wandb",
2727
"omegaconf",

tests/buffer/queue_test.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
1-
import unittest
2-
3-
import ray
41
import torch
52

3+
from tests.tools import RayUnittestBase
64
from trinity.buffer.reader.queue_reader import QueueReader
75
from trinity.buffer.writer.queue_writer import QueueWriter
86
from trinity.common.config import BufferConfig, DatasetConfig
97
from trinity.common.constants import AlgorithmType, StorageType
108
from trinity.common.experience import Experience
119

1210

13-
class TestQueueBuffer(unittest.TestCase):
14-
def setUp(self):
15-
ray.init(ignore_reinit_error=True)
16-
11+
class TestQueueBuffer(RayUnittestBase):
1712
def test_queue_buffer(self):
1813
total_num = 8
1914
put_batch_size = 2
2015
read_batch_size = 4
2116
meta = DatasetConfig(
2217
name="test_buffer",
18+
namespace="test_namespace",
2319
algorithm_type=AlgorithmType.PPO,
2420
storage_type=StorageType.QUEUE,
2521
)

tests/common/config_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ def test_all_examples_are_valid(self):
2626
example_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
2727
for example_name in os.listdir(example_dir):
2828
for filename in os.listdir(os.path.join(example_dir, example_name)):
29-
if filename.endswith(".yaml") and not filename.startswith("train"):
29+
if filename.endswith(".yaml") and not (
30+
filename.startswith("train_") or filename.startswith("verl_")
31+
):
3032
print(f"Checking config: {filename}")
3133
config_path = os.path.join(example_dir, example_name, filename)
3234
try:

tests/common/vllm_test.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -129,47 +129,76 @@ def test_generate(self):
129129
self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
130130

131131

132-
class TestModelWrapperSync(BaseTestModelWrapper, RayUnittestBase):
132+
class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
133133
def setUp(self):
134134
ray.init(ignore_reinit_error=True)
135135
self.config = get_template_config()
136136
self.config.model.model_path = get_model_path()
137137
self.config.explorer.engine_type = "vllm"
138-
self.config.explorer.engine_num = 1
138+
self.config.explorer.tensor_parallel_size = 1
139+
self.config.explorer.engine_num = 2
139140
self.config.explorer.chat_template = CHAT_TEMPLATE
140141
self.engines = create_rollout_models(self.config)
141142
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
142143

143144

144-
class TestModelWrapperAsync(BaseTestModelWrapper, RayUnittestBase):
145-
@classmethod
146-
def setUpClass(cls):
145+
class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
146+
def setUp(self):
147147
ray.init(ignore_reinit_error=True)
148+
self.config = get_template_config()
149+
self.config.model.model_path = get_model_path()
150+
self.config.explorer.engine_type = "vllm_async"
151+
self.config.explorer.engine_num = 2
152+
self.config.explorer.tensor_parallel_size = 1
153+
self.config.explorer.use_v1 = False
154+
self.config.explorer.chat_template = CHAT_TEMPLATE
155+
self.engines = create_rollout_models(self.config)
156+
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
148157

149-
@classmethod
150-
def tearDownClass(cls):
151-
ray.shutdown()
152158

159+
class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase):
153160
def setUp(self):
154161
ray.init(ignore_reinit_error=True)
155162
self.config = get_template_config()
156163
self.config.model.model_path = get_model_path()
157164
self.config.explorer.engine_type = "vllm_async"
158-
self.config.explorer.engine_num = 1
165+
self.config.explorer.engine_num = 2
166+
self.config.explorer.tensor_parallel_size = 2
167+
self.config.explorer.use_v1 = False
159168
self.config.explorer.chat_template = CHAT_TEMPLATE
160169
self.engines = create_rollout_models(self.config)
161170
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
162171

163172

164-
class TestTokenizer(unittest.TestCase):
165-
@classmethod
166-
def setUpClass(cls):
173+
class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase):
174+
def setUp(self):
175+
ray.init(ignore_reinit_error=True)
176+
self.config = get_template_config()
177+
self.config.model.model_path = get_model_path()
178+
self.config.explorer.engine_type = "vllm_async"
179+
self.config.explorer.engine_num = 2
180+
self.config.explorer.tensor_parallel_size = 2
181+
self.config.explorer.use_v1 = True
182+
self.config.explorer.chat_template = CHAT_TEMPLATE
183+
self.engines = create_rollout_models(self.config)
184+
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
185+
186+
187+
class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase):
188+
def setUp(self):
167189
ray.init(ignore_reinit_error=True)
190+
self.config = get_template_config()
191+
self.config.model.model_path = get_model_path()
192+
self.config.explorer.engine_type = "vllm_async"
193+
self.config.explorer.engine_num = 2
194+
self.config.explorer.tensor_parallel_size = 1
195+
self.config.explorer.use_v1 = True
196+
self.config.explorer.chat_template = CHAT_TEMPLATE
197+
self.engines = create_rollout_models(self.config)
198+
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
168199

169-
@classmethod
170-
def tearDownClass(cls):
171-
ray.shutdown()
172200

201+
class TestTokenizer(unittest.TestCase):
173202
def test_assistant_token_mask(self):
174203
messages = [
175204
{"role": "system", "content": "You are a helpful assistant."},

tests/explorer/explorer_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class TestExplorerCountdownEval(BaseExplorerCase):
3737
def test_explorer(self):
3838
self.config.data = get_unittest_dataset_config("countdown")
3939
self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
40+
self.config.explorer.use_v1 = True
4041
self.config.check_and_update()
4142
explore(self.config)
4243
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
@@ -53,6 +54,7 @@ def test_explorer(self):
5354
self.config.data = get_unittest_dataset_config("countdown")
5455
self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
5556
self.config.data.eval_split = None
57+
self.config.explorer.use_v1 = False
5658
self.config.check_and_update()
5759
explore(self.config)
5860
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))

tests/explorer/runner_pool_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def setUp(self):
7070
self.config.buffer.pad_token_id = 0
7171
self.config.buffer.train_dataset = DatasetConfig(
7272
name="test",
73+
namespace="test_runner_pool",
7374
storage_type=StorageType.QUEUE,
7475
algorithm_type=AlgorithmType.PPO,
7576
)

tests/template/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ explorer:
3232
logprobs: 0
3333
backend: nccl
3434
use_ray: false
35+
use_v1: true
3536
trainer:
3637
trainer_type: verl
3738
trainer_config_path: tests/template/verl_config.yaml

tests/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,4 @@ def setUpClass(cls):
102102

103103
@classmethod
104104
def tearDownClass(cls):
105-
ray.shutdown()
105+
ray.shutdown(_exiting_interpreter=True)

trinity/common/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ class ExplorerConfig:
173173
use_ray: bool = False
174174
gpu_memory_utilization: float = 0.9
175175
enable_chunked_prefil: bool = False
176+
use_v1: bool = True
177+
bundle_indices: str = "" # DO NOT SET this field
176178

177179
# for workflow runner
178180
max_pending_requests: int = 5

0 commit comments

Comments
 (0)