Skip to content

Commit 4c12852

Browse files
committed
placement, asyncllm, and basic tests
1 parent 20ad93a commit 4c12852

File tree

15 files changed

+348
-1937
lines changed

15 files changed

+348
-1937
lines changed

examples/llm-api/rl_integration_test.py

Lines changed: 0 additions & 618 deletions
This file was deleted.

examples/llm-api/rl_integration_test_async.py

Lines changed: 0 additions & 647 deletions
This file was deleted.

examples/rl/rl_integration_test.py

Lines changed: 0 additions & 618 deletions
This file was deleted.

tensorrt_llm/_torch/async_llm.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Any, Optional
2+
3+
from ..llmapi.llm import LLM
4+
from .virtual_memory import ExecutorMemoryType
5+
6+
7+
class AsyncLLM(LLM):
    """An LLM variant with asynchronous lifecycle control.

    Extends :class:`LLM` with async setup, GPU-memory release/resume, and
    weight-update operations needed for RL / agentic serving scenarios.
    Only the Ray orchestrator is supported.
    """

    def __init__(self, *args, **kwargs):
        # AsyncLLM is only supported with Ray orchestrator now, so force it.
        kwargs["orchestrator_type"] = "ray"
        # Install the default RLHF worker extension unless the caller
        # provided their own.
        kwargs.setdefault('ray_worker_extension_cls',
                          'tensorrt_llm.llmapi.rlhf_utils.WorkerExtension')
        super().__init__(*args, **kwargs)

    async def setup_async(self):
        """Initialize the executor's workers asynchronously."""
        await self._executor.init_workers_async()

    async def release(self, tags: list[str]):
        """Release GPU memory held by the LLM, asynchronously.

        Args:
            tags: List of memory tag strings to release (e.g., ["model", "kv_cache"]).
        """
        await self.collective_rpc("sleep", args=(tags,))

    async def resume(self, tags: list[str]):
        """Re-acquire GPU memory previously released via :meth:`release`.

        Args:
            tags: List of memory tag strings to resume (e.g., ["model", "kv_cache"]).
        """
        await self.collective_rpc("wakeup", args=(tags,))

    async def update_weights(self, weights: dict[str, str]):
        """Push new model weights to all workers asynchronously.

        Args:
            weights: Dictionary mapping device UUIDs to IPC handles for weight tensors.
        """
        await self.collective_rpc("update_weights", args=(weights,))

    async def collective_rpc(
        self,
        method: str,
        args: tuple[Any, ...] = (),
        kwargs: Optional[dict] = None,
        unique_reply_rank: Optional[int] = None,
    ) -> list[Any]:
        """Invoke a method on every GPU worker and await the results.

        Currently only supported when the executor is a RayExecutor.

        Args:
            method (str): The name of the worker method to execute.
            args (tuple[Any, ...]): Positional arguments for the worker method. Defaults to ().
            kwargs (dict, optional): Keyword arguments for the worker method. Defaults to None.
            unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply.

        Returns:
            list[Any]: A list of results from each worker.
        """
        result = await self._executor.collective_rpc_async(
            method, args, kwargs, unique_reply_rank=unique_reply_rank
        )
        return result

tensorrt_llm/_torch/virtual_memory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ class ExecutorMemoryType(StrEnum):
7474
SPEC_RESOURCES = "spec_resource_manager"
7575
INIT_KV_CACHE = "_no_capture_init_kv_cache"
7676
INIT_EXTRA_RESOURCES = "_no_capture_init_extra_resources"
77-
MODEL_EXTRA = "_no_capture_model_extra" # TODO: remove _no_capture after torch fix crash on torch.cuda.empty_cache()
77+
# MODEL_EXTRA = "_no_capture_model_extra" # TODO: remove _no_capture after torch fix crash on torch.cuda.empty_cache()
78+
MODEL_EXTRA = "model_extra"
7879
EXTRA_RESOURCES = "executor_extra"
7980
KV_CACHE = "kv_cache"
8081
MODEL_ENGINE_MAIN = "model"

tensorrt_llm/executor/ray_executor.py

Lines changed: 84 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import os
21
import asyncio
2+
import os
33
from typing import Any, Dict, List, Optional, Tuple
44

55
try:
@@ -8,8 +8,7 @@
88
e.msg = """Cannot import Ray. Please install 'ray' package to use ray orchestrator"""
99
raise
1010

11-
from ray.util.placement_group import (PlacementGroup,
12-
PlacementGroupSchedulingStrategy,
11+
from ray.util.placement_group import (PlacementGroupSchedulingStrategy,
1312
get_current_placement_group,
1413
placement_group)
1514

@@ -79,15 +78,15 @@ def __init__(self,
7978
self.master_address = ray.util.get_node_ip_address()
8079
self.master_port = get_free_port()
8180

82-
self.worker_kwargs = dict(**worker_kwargs,
83-
postproc_worker_config=postproc_worker_config,
84-
is_llm_executor=is_llm_executor)
85-
if not has_event_loop():
86-
self.init_workers_sync()
81+
self.worker_kwargs = dict(
82+
**worker_kwargs,
83+
postproc_worker_config=postproc_worker_config,
84+
is_llm_executor=is_llm_executor)
8785

8886
self.init_rpc_executor()
8987
worker_kwargs['rpc_addr'] = self.rpc_addr
90-
self.create_workers(RayGPUWorker, worker_kwargs)
88+
if not has_event_loop():
89+
self.init_workers_sync()
9190
self.setup_engine_remote()
9291
self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
9392
thread_name="ray_executor_main_loop")
@@ -99,9 +98,13 @@ def __init__(self,
9998
raise e
10099

101100
def create_workers(self, worker_cls, worker_kwargs):
101+
llm_args = worker_kwargs.get("llm_args")
102+
102103
# When set to be a fraction, it allows Ray to schedule
103104
# multiple actors on a single GPU for colocate use cases.
104-
num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))
105+
num_gpus = (llm_args.per_worker_gpu_share if llm_args
106+
and llm_args.per_worker_gpu_share is not None else float(
107+
os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0")))
105108
logger.debug(f"{num_gpus=} for each worker.")
106109

107110
runtime_env = ray.runtime_env.RuntimeEnv()
@@ -112,42 +115,40 @@ def create_workers(self, worker_cls, worker_kwargs):
112115
"MASTER_PORT": str(self.master_port)
113116
})
114117

115-
self.placement_group, self.bundle_indices = self._get_placement_group(
116-
tp_size=self.tp_size)
118+
placement_groups, self.bundle_indices = self._get_placement_group(
119+
tp_size=self.tp_size, worker_kwargs=worker_kwargs)
117120

118-
self.workers = [
119-
RayWorkerWrapper.options(
121+
if isinstance(placement_groups, list):
122+
self.placement_group = None
123+
else:
124+
self.placement_group = placement_groups
125+
126+
self.workers = []
127+
for rank in range(self.world_size):
128+
pg = placement_groups[rank] if isinstance(
129+
placement_groups, list) else placement_groups
130+
worker = RayWorkerWrapper.options(
120131
num_gpus=num_gpus,
121-
runtime_env=runtime_env, # per-actor env
132+
runtime_env=runtime_env,
122133
scheduling_strategy=PlacementGroupSchedulingStrategy(
123-
placement_group=self.placement_group,
134+
placement_group=pg,
124135
placement_group_bundle_index=self.bundle_indices[rank],
125136
)).remote(worker_cls, worker_kwargs, self.world_size, rank)
126-
for rank in range(self.world_size)
127-
]
137+
self.workers.append(worker)
128138

129139
def init_workers_sync(self):
130140
self.create_workers(RayGPUWorker, self.worker_kwargs)
131141
try:
132-
ray.get([worker.__ray_ready__.remote() for worker in self.workers])
142+
ray.get(self._get_worker_ready_futures())
133143
except ray.exceptions.ActorDiedError as e:
134-
if "The actor died because of an error raised in its creation task" in str(
135-
e):
136-
raise RuntimeError(
137-
"RayGPUWorker died during initialization") from e
138-
raise
144+
raise RuntimeError("RayGPUWorker died during initialization") from e
139145

140146
async def init_workers_async(self):
141147
self.create_workers(RayGPUWorker, self.worker_kwargs)
142148
try:
143-
await asyncio.gather(*[worker.__ray_ready__.remote() for worker in self.workers])
149+
await asyncio.gather(*self._get_worker_ready_futures())
144150
except ray.exceptions.ActorDiedError as e:
145-
if "The actor died because of an error raised in its creation task" in str(
146-
e):
147-
raise RuntimeError(
148-
"RayGPUWorker died during initialization") from e
149-
raise
150-
151+
raise RuntimeError("RayGPUWorker died during initialization") from e
151152

152153
@unwrap_ray_errors()
153154
def call_all_ray_workers(self, func: str, leader_only: bool,
@@ -187,6 +188,20 @@ def collective_rpc(self,
187188
**kwargs))
188189
return refs if non_block else ray.get(refs)
189190

191+
@unwrap_ray_errors()
async def collective_rpc_async(
        self,
        method: str,
        args: tuple = (),
        kwargs: Optional[dict] = None,
        unique_reply_rank: Optional[int] = None) -> list[Any]:
    """Async variant of ``collective_rpc``: fan out the RPC to all workers
    without blocking, then await every per-worker Ray future.

    Args:
        method: Name of the worker method to execute.
        args: Positional arguments forwarded to the worker method.
        kwargs: Keyword arguments forwarded to the worker method.
        unique_reply_rank: If set, only this rank's reply is requested.

    Returns:
        One result per worker, gathered concurrently.
    """
    # non_block=True makes collective_rpc return Ray object refs
    # instead of resolved values, so they can be awaited here.
    pending = self.collective_rpc(method,
                                  args,
                                  kwargs,
                                  non_block=True,
                                  unique_reply_rank=unique_reply_rank)
    return await asyncio.gather(*pending)
204+
190205
def submit(self, request: "GenerationRequest") -> "GenerationResult":
191206
"""
192207
Low-level API to the executor. Return a "future" GenerationResult
@@ -281,15 +296,51 @@ def shutdown(self):
281296
logger.debug("Shutting down Ray cluster")
282297
ray.shutdown()
283298

284-
def _get_placement_group(self,
285-
tp_size: int) -> Tuple[PlacementGroup, List[int]]:
299+
def _get_worker_ready_futures(self):
300+
return [worker.__ray_ready__.remote() for worker in self.workers]
301+
302+
def _get_placement_group(
303+
self,
304+
tp_size: int,
305+
worker_kwargs: Dict = None) -> Tuple[Any, List[int]]:
286306
"""
287307
Either use the existing placement group from driver script (e.g., in the case of RL FW integration),
288308
or create a default PACK placement group where each bundle has tp_size GPUs.
289309
- When tp_size ≤ GPUs per node, keep one TP group per node.
290310
- When tp_size > GPUs per node, allow a TP group span nodes.
291311
- rank 0 must be put on the driver node
312+
313+
Returns:
314+
Tuple of (placement_group(s), bundle_indices)
315+
- placement_group(s) can be a single PlacementGroup or a List[PlacementGroup]
316+
- bundle_indices is always a List[int]
292317
"""
318+
llm_args = worker_kwargs.get("llm_args") if worker_kwargs else None
319+
320+
if llm_args and hasattr(
321+
llm_args,
322+
'placement_groups') and llm_args.placement_groups is not None:
323+
total_workers = sum(
324+
len(indices) for indices in llm_args.placement_bundle_indices)
325+
if total_workers != self.world_size:
326+
raise ValueError(
327+
f"Total bundle indices ({total_workers}) must equal world_size ({self.world_size})"
328+
)
329+
330+
logger.info(
331+
f"Creating {self.world_size} workers with external placement groups"
332+
)
333+
334+
flat_pgs = []
335+
flat_indices = []
336+
for pg, indices in zip(llm_args.placement_groups,
337+
llm_args.placement_bundle_indices):
338+
for idx in indices:
339+
flat_pgs.append(pg)
340+
flat_indices.append(idx)
341+
342+
return flat_pgs, flat_indices
343+
293344
bundle_indices = os.getenv("TRTLLM_RAY_BUNDLE_INDICES", None)
294345

295346
if bundle_indices:

tensorrt_llm/executor/ray_gpu_worker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gc
12
import importlib
23
import os
34
from pathlib import Path
@@ -216,6 +217,8 @@ def sleep(self, sleep_tags: List[str]):
216217
torch.cuda.synchronize()
217218
release_with_tag(*tags)
218219
torch.cuda.synchronize()
220+
gc.collect()
221+
torch.cuda.empty_cache()
219222
except Exception as e:
220223
logger.error(f"Encountered an error in sleep: {e}")
221224
raise e

tensorrt_llm/llmapi/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from ..executor import CompletionOutput, LoRARequest, RequestError
33
from ..sampling_params import GuidedDecodingParams, SamplingParams
44
from .build_cache import BuildCacheConfig
5-
from .llm import LLM, AsyncLLM, RequestOutput
5+
from .llm import LLM, RequestOutput
6+
from .._torch.async_llm import AsyncLLM
67
# yapf: disable
78
from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
89
CacheTransceiverConfig, CalibConfig,

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
# TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
4949
from .utils import (append_docstring, exception_handler, get_device_count,
5050
logger_debug, set_api_status)
51+
from ray.util.placement_group import PlacementGroup, placement_group
5152

5253

5354
class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
@@ -1149,10 +1150,3 @@ def __init__(self,
11491150
11501151
Parameters:
11511152
""" + TORCH_LLM_DOCSTRING
1152-
1153-
class AsyncLLM(LLM):
1154-
def __init__(self, *args, **kwargs):
1155-
super().__init__(*args, **kwargs)
1156-
1157-
async def async_init_phase(self):
1158-
await self._executor.init_workers_async()

0 commit comments

Comments
 (0)