Skip to content

Commit 6c56401

Browse files
authored
Rollout use vLLM V1 engine (#31)
1 parent 514953f commit 6c56401

File tree

14 files changed

+137
-122
lines changed

14 files changed

+137
-122
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ requires-python = ">=3.10"
2121
dependencies = [
2222
"verl==0.3.0.post1",
2323
"ray[default]==2.43.0",
24-
"vllm>=0.8.3",
24+
"vllm>=0.8.5",
2525
"tensordict==0.6.2",
2626
"wandb",
2727
"omegaconf",

tests/buffer/queue_test.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
1-
import unittest
2-
3-
import ray
41
import torch
52

3+
from tests.tools import RayUnittestBase
64
from trinity.buffer.reader.queue_reader import QueueReader
75
from trinity.buffer.writer.queue_writer import QueueWriter
86
from trinity.common.config import BufferConfig, DatasetConfig
97
from trinity.common.constants import AlgorithmType, StorageType
108
from trinity.common.experience import Experience
119

1210

13-
class TestQueueBuffer(unittest.TestCase):
14-
def setUp(self):
15-
ray.init(ignore_reinit_error=True)
16-
11+
class TestQueueBuffer(RayUnittestBase):
1712
def test_queue_buffer(self):
1813
total_num = 8
1914
put_batch_size = 2
2015
read_batch_size = 4
2116
meta = DatasetConfig(
2217
name="test_buffer",
18+
namespace="test_namespace",
2319
algorithm_type=AlgorithmType.PPO,
2420
storage_type=StorageType.QUEUE,
2521
)

tests/common/config_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ def test_all_examples_are_valid(self):
2626
example_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
2727
for example_name in os.listdir(example_dir):
2828
for filename in os.listdir(os.path.join(example_dir, example_name)):
29-
if filename.endswith(".yaml") and not filename.startswith("train"):
29+
if filename.endswith(".yaml") and not (
30+
filename.startswith("train_") or filename.startswith("verl_")
31+
):
3032
print(f"Checking config: {filename}")
3133
config_path = os.path.join(example_dir, example_name, filename)
3234
try:

tests/common/vllm_test.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -129,47 +129,76 @@ def test_generate(self):
129129
self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
130130

131131

132-
class TestModelWrapperSync(BaseTestModelWrapper, RayUnittestBase):
132+
class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
133133
def setUp(self):
134134
ray.init(ignore_reinit_error=True)
135135
self.config = get_template_config()
136136
self.config.model.model_path = get_model_path()
137137
self.config.explorer.engine_type = "vllm"
138-
self.config.explorer.engine_num = 1
138+
self.config.explorer.tensor_parallel_size = 1
139+
self.config.explorer.engine_num = 2
139140
self.config.explorer.chat_template = CHAT_TEMPLATE
140141
self.engines = create_rollout_models(self.config)
141142
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
142143

143144

144-
class TestModelWrapperAsync(BaseTestModelWrapper, RayUnittestBase):
145-
@classmethod
146-
def setUpClass(cls):
145+
class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
146+
def setUp(self):
147147
ray.init(ignore_reinit_error=True)
148+
self.config = get_template_config()
149+
self.config.model.model_path = get_model_path()
150+
self.config.explorer.engine_type = "vllm_async"
151+
self.config.explorer.engine_num = 2
152+
self.config.explorer.tensor_parallel_size = 1
153+
self.config.explorer.use_v1 = False
154+
self.config.explorer.chat_template = CHAT_TEMPLATE
155+
self.engines = create_rollout_models(self.config)
156+
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
148157

149-
@classmethod
150-
def tearDownClass(cls):
151-
ray.shutdown()
152158

159+
class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase):
153160
def setUp(self):
154161
ray.init(ignore_reinit_error=True)
155162
self.config = get_template_config()
156163
self.config.model.model_path = get_model_path()
157164
self.config.explorer.engine_type = "vllm_async"
158-
self.config.explorer.engine_num = 1
165+
self.config.explorer.engine_num = 2
166+
self.config.explorer.tensor_parallel_size = 2
167+
self.config.explorer.use_v1 = False
159168
self.config.explorer.chat_template = CHAT_TEMPLATE
160169
self.engines = create_rollout_models(self.config)
161170
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
162171

163172

164-
class TestTokenizer(unittest.TestCase):
165-
@classmethod
166-
def setUpClass(cls):
173+
class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase):
174+
def setUp(self):
175+
ray.init(ignore_reinit_error=True)
176+
self.config = get_template_config()
177+
self.config.model.model_path = get_model_path()
178+
self.config.explorer.engine_type = "vllm_async"
179+
self.config.explorer.engine_num = 2
180+
self.config.explorer.tensor_parallel_size = 2
181+
self.config.explorer.use_v1 = True
182+
self.config.explorer.chat_template = CHAT_TEMPLATE
183+
self.engines = create_rollout_models(self.config)
184+
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
185+
186+
187+
class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase):
188+
def setUp(self):
167189
ray.init(ignore_reinit_error=True)
190+
self.config = get_template_config()
191+
self.config.model.model_path = get_model_path()
192+
self.config.explorer.engine_type = "vllm_async"
193+
self.config.explorer.engine_num = 2
194+
self.config.explorer.tensor_parallel_size = 1
195+
self.config.explorer.use_v1 = True
196+
self.config.explorer.chat_template = CHAT_TEMPLATE
197+
self.engines = create_rollout_models(self.config)
198+
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
168199

169-
@classmethod
170-
def tearDownClass(cls):
171-
ray.shutdown()
172200

201+
class TestTokenizer(unittest.TestCase):
173202
def test_assistant_token_mask(self):
174203
messages = [
175204
{"role": "system", "content": "You are a helpful assistant."},

tests/explorer/explorer_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class TestExplorerCountdownEval(BaseExplorerCase):
3737
def test_explorer(self):
3838
self.config.data = get_unittest_dataset_config("countdown")
3939
self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
40+
self.config.explorer.use_v1 = True
4041
self.config.check_and_update()
4142
explore(self.config)
4243
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
@@ -53,6 +54,7 @@ def test_explorer(self):
5354
self.config.data = get_unittest_dataset_config("countdown")
5455
self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
5556
self.config.data.eval_split = None
57+
self.config.explorer.use_v1 = False
5658
self.config.check_and_update()
5759
explore(self.config)
5860
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))

tests/explorer/runner_pool_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def setUp(self):
7070
self.config.buffer.pad_token_id = 0
7171
self.config.buffer.train_dataset = DatasetConfig(
7272
name="test",
73+
namespace="test_runner_pool",
7374
storage_type=StorageType.QUEUE,
7475
algorithm_type=AlgorithmType.PPO,
7576
)

tests/template/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ explorer:
3232
logprobs: 0
3333
backend: nccl
3434
use_ray: false
35+
use_v1: true
3536
trainer:
3637
trainer_type: verl
3738
trainer_config_path: tests/template/verl_config.yaml

tests/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,4 @@ def setUpClass(cls):
102102

103103
@classmethod
104104
def tearDownClass(cls):
105-
ray.shutdown()
105+
ray.shutdown(_exiting_interpreter=True)

trinity/common/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ class ExplorerConfig:
173173
use_ray: bool = False
174174
gpu_memory_utilization: float = 0.9
175175
enable_chunked_prefil: bool = False
176+
use_v1: bool = True
177+
bundle_indices: str = "" # DO NOT SET this field
176178

177179
# for workflow runner
178180
max_pending_requests: int = 5

trinity/common/models/__init__.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
from typing import List
23

34
from trinity.common.config import Config
@@ -12,20 +13,15 @@ def create_rollout_models(
1213
Each model has `tensor_parallel_size` workers.
1314
"""
1415
import ray
15-
import vllm
16-
from ray.util.placement_group import placement_group
16+
from ray.util.placement_group import placement_group, placement_group_table
1717
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
1818

1919
from trinity.common.models.vllm_async_model import vLLMAysncRolloutModel
2020
from trinity.common.models.vllm_model import vLLMRolloutModel
21-
from trinity.utils.log import get_logger
22-
23-
logger = get_logger(__name__)
24-
25-
assert vllm.__version__ >= "0.7.3", "Trinity-RFT only supports vllm >= 0.7.3"
2621

2722
engine_num = config.explorer.engine_num
2823
tensor_parallel_size = config.explorer.tensor_parallel_size
24+
is_multi_process = config.explorer.tensor_parallel_size > 1
2925

3026
vllm_engines = []
3127

@@ -36,28 +32,39 @@ def create_rollout_models(
3632
else:
3733
raise ValueError(f"Unknown engine type: {config.explorer.engine_type}")
3834

39-
bundles = [{"GPU": tensor_parallel_size, "CPU": 1} for _ in range(engine_num)]
40-
pg = placement_group(bundles)
35+
bundles = [{"GPU": 1} for _ in range(engine_num * tensor_parallel_size)]
36+
pg = placement_group(bundles, strategy="PACK")
4137
ray.get(pg.ready())
4238

43-
for i in range(engine_num):
44-
logger.info(f"Creating vLLM engine {i}")
45-
scheduling_strategy = None
39+
vllm_engines = []
4640

47-
scheduling_strategy = PlacementGroupSchedulingStrategy(
48-
placement_group=pg,
49-
placement_group_capture_child_tasks=True,
50-
placement_group_bundle_index=i,
51-
)
41+
# to address https://github.com/ray-project/ray/issues/51117
42+
# aggregate bundles belonging to the same node
43+
bundle_node_map = placement_group_table(pg)["bundles_to_node_id"]
44+
node_bundle_map = defaultdict(list)
45+
for bundle_id, node_id in bundle_node_map.items():
46+
node_bundle_map[node_id].append(bundle_id)
5247

53-
vllm_engines.append(
54-
engine_cls.options( # type: ignore [attr-defined]
55-
num_cpus=0,
56-
num_gpus=tensor_parallel_size,
57-
scheduling_strategy=scheduling_strategy,
58-
).remote(
59-
config=config,
60-
)
48+
for node_id, bundle_ids in node_bundle_map.items():
49+
assert len(bundle_ids) % tensor_parallel_size == 0, (
50+
f"Node {node_id} has {len(bundle_ids)} bundles, "
51+
f"which is not divisible by tensor_parallel_size({tensor_parallel_size})"
6152
)
62-
53+
for i in range(len(bundle_ids) // tensor_parallel_size):
54+
bundles_for_engine = bundle_ids[
55+
i * tensor_parallel_size : (i + 1) * tensor_parallel_size
56+
]
57+
config.explorer.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine])
58+
vllm_engines.append(
59+
engine_cls.options(
60+
num_cpus=0,
61+
num_gpus=0 if is_multi_process else 1,
62+
scheduling_strategy=PlacementGroupSchedulingStrategy(
63+
placement_group=pg,
64+
placement_group_bundle_index=bundles_for_engine[0],
65+
),
66+
).remote(
67+
config=config,
68+
)
69+
)
6370
return vllm_engines

0 commit comments

Comments
 (0)