Skip to content

Commit 565208b

Browse files
committed
Add multi instance test
Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>
1 parent 01700aa commit 565208b

File tree

2 files changed

+115
-12
lines changed

2 files changed

+115
-12
lines changed

examples/ray_orchestrator/rl_perf_repro.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020

2121
@ray.remote
22-
class trtllm_instance:
22+
class TRTLLMInstance:
2323
def __init__(self, async_llm_kwargs: dict, sampling_kwargs: dict):
2424
self.async_llm_kwargs = async_llm_kwargs
2525
self.sampling_kwargs = sampling_kwargs
@@ -62,7 +62,7 @@ async def init_llm(self):
6262
)
6363

6464
async def generate(self, prompt: list[int]):
65-
"""Generate for a single prompt"""
65+
"""Generate for a single prompt."""
6666
outputs = await self.llm.generate_async(inputs=prompt, sampling_params=self.sampling_params)
6767
token_ids = outputs.outputs[0].token_ids
6868
log_probs = None
@@ -111,7 +111,7 @@ async def setup_rl_llm(args):
111111
llm_instances = []
112112
for i in range(num_instances):
113113
llm_instances.append(
114-
trtllm_instance.options(
114+
TRTLLMInstance.options(
115115
num_cpus=0,
116116
num_gpus=0,
117117
runtime_env=runtime_env,
@@ -130,7 +130,7 @@ async def setup_rl_llm(args):
130130
"free_gpu_memory_fraction": args.kv_cache_fraction,
131131
},
132132
"cuda_graph_config": {
133-
"enable_padding": args.enable_padding,
133+
"enable_padding": args.enable_cuda_graph_padding,
134134
"batch_sizes": args.batch_sizes,
135135
"max_batch_size": 0 if args.batch_sizes else args.max_batch_size,
136136
},
@@ -171,7 +171,7 @@ async def setup_rl_llm(args):
171171

172172
# Helper function to wrap Ray remote call as async coroutine
173173
async def generate_single_prompt(instance, prompt):
174-
"""Generate a single prompt asynchronously"""
174+
"""Generate a single prompt asynchronously."""
175175
object_ref = instance.generate.remote(prompt=prompt)
176176
result = await asyncio.to_thread(ray.get, object_ref)
177177
return result
@@ -182,7 +182,7 @@ async def generate_single_prompt(instance, prompt):
182182
for idx, prompt in enumerate(prompts)
183183
]
184184

185-
results = await asyncio.gather(*tasks)
185+
await asyncio.gather(*tasks)
186186
end_time = time.time()
187187

188188
print(f"Time taken: {end_time - start_time:.2f} seconds")
@@ -216,7 +216,10 @@ def add_rl_llm_args(parser):
216216
help="Sampler type.",
217217
)
218218
parser.add_argument(
219-
"--trust_remote_code", type=bool, default=True, help="Whether to trust remote code."
219+
"--trust_remote_code",
220+
action="store_true",
221+
default=False,
222+
help="Whether to trust remote code.",
220223
)
221224

222225
# KV Cache Config parameters
@@ -228,16 +231,16 @@ def add_rl_llm_args(parser):
228231
)
229232
parser.add_argument(
230233
"--enable_block_reuse",
231-
type=bool,
232-
default=True,
234+
action="store_true",
235+
default=False,
233236
help="Whether to enable block reuse for KV cache.",
234237
)
235238

236239
# Cuda Graph Config parameters
237240
parser.add_argument(
238-
"--enable_padding",
239-
type=bool,
240-
default=True,
241+
"--enable_cuda_graph_padding",
242+
action="store_true",
243+
default=False,
241244
help="Whether to enable padding for CUDA graphs.",
242245
)
243246
parser.add_argument(
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import os
2+
3+
import pytest
4+
import ray
5+
import torch
6+
from ray.util.placement_group import (
7+
PlacementGroupSchedulingStrategy,
8+
placement_group,
9+
remove_placement_group,
10+
)
11+
from utils.llm_data import llm_models_root
12+
13+
from tensorrt_llm import AsyncLLM
14+
from tensorrt_llm.llmapi import KvCacheConfig
15+
16+
17+
@ray.remote
class TRTLLMInstance:
    """Ray actor that hosts a single AsyncLLM engine.

    The constructor builds the engine from a flat kwargs dict so the caller
    can ship one serializable payload through ``.remote(...)``; the heavy
    asynchronous setup is deferred to :meth:`init_llm`.
    """

    def __init__(self, async_llm_kwargs: dict):
        # Pull each expected key explicitly so a missing key fails fast
        # with a clear KeyError at actor construction time.
        cfg = async_llm_kwargs
        kv_cache_config = KvCacheConfig(**cfg["kv_cache_config"])
        self.llm = AsyncLLM(
            model=cfg["model"],
            backend="pytorch",
            orchestrator_type=cfg["orchestrator_type"],
            kv_cache_config=kv_cache_config,
            tensor_parallel_size=cfg["tensor_parallel_size"],
            placement_groups=cfg["placement_groups"],
            placement_bundle_indices=cfg["placement_bundle_indices"],
            per_worker_gpu_share=cfg["per_worker_gpu_share"],
        )

    async def init_llm(self):
        """Run the engine's asynchronous setup; call once after construction."""
        await self.llm.setup_async()
33+
34+
35+
@pytest.mark.gpu8
@pytest.mark.parametrize(
    "tp_size, num_instances", [(2, 4), (1, 8)], ids=["tp2_instances4", "tp1_instances8"]
)
def test_multi_instance(monkeypatch, tp_size, num_instances):
    """Repeatedly create and tear down multiple AsyncLLM Ray actors on one node.

    For each of several iterations: start a local Ray cluster, reserve one
    STRICT_PACK placement group covering ``tp_size * num_instances`` GPUs,
    launch ``num_instances`` TRTLLMInstance actors (each a TP-``tp_size``
    engine pinned to its own bundle slice), wait for them to initialize,
    then clean everything up. Exercises multi-instance setup/teardown.
    """
    # Let TRT-LLM manage CUDA device visibility itself rather than Ray.
    monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")

    num_gpus = tp_size * num_instances
    available_gpus = torch.cuda.device_count()
    if num_gpus > 8:
        raise ValueError(
            f"Number of GPUs ({num_gpus}) is greater than 8. This script only supports single node."
        )
    if available_gpus < num_gpus:
        raise ValueError(
            f"Number of GPUs ({available_gpus}) is less than number of GPUs required ({num_gpus})."
        )

    # Fixed typo: was `excution_times`. Loop index is unused, so use `_`.
    execution_times = 5
    for _ in range(execution_times):
        pg = None
        try:
            ray.init(address="local")
            # One bundle per GPU; STRICT_PACK keeps all bundles on one node.
            pg = placement_group(
                [{"GPU": 1, "CPU": 2} for _ in range(num_gpus)], strategy="STRICT_PACK"
            )

            ray.get(pg.ready())

            # Instance i owns bundles [i*tp_size, (i+1)*tp_size) of the shared group.
            placement_group_list = [[pg] for _ in range(num_instances)]
            placement_bundle_indices_list = [
                [list(range(i * tp_size, (i + 1) * tp_size))] for i in range(num_instances)
            ]

            llm_instances = []
            for i in range(num_instances):
                llm_instances.append(
                    TRTLLMInstance.options(
                        # The actor itself needs no resources; the engine workers
                        # claim GPUs through the placement group below.
                        num_cpus=0,
                        num_gpus=0,
                        scheduling_strategy=PlacementGroupSchedulingStrategy(
                            placement_group=pg,
                            placement_group_capture_child_tasks=True,
                        ),
                    ).remote(
                        async_llm_kwargs={
                            "model": os.path.join(
                                llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0"
                            ),
                            "kv_cache_config": {
                                # Small fraction so many engines fit on each GPU.
                                "free_gpu_memory_fraction": 0.1,
                            },
                            "tensor_parallel_size": tp_size,
                            "orchestrator_type": "ray",
                            "placement_groups": placement_group_list[i],
                            "placement_bundle_indices": placement_bundle_indices_list[i],
                            "per_worker_gpu_share": 0.5,
                        }
                    )
                )
            # Wait for actor construction, then for engine setup, across all instances.
            ray.get([llm.__ray_ready__.remote() for llm in llm_instances])
            ray.get([llm.init_llm.remote() for llm in llm_instances])
        finally:
            # Release the GPU reservation and cluster even if setup failed.
            if pg is not None:
                remove_placement_group(pg)
            ray.shutdown()

0 commit comments

Comments
 (0)