Commit e0e6664

hchings, davidmlw, tongyuantongyu, and Superjomn authored and committed

[TRTLLM-9736][feat] AsyncLLM and verl integ (NVIDIA#9353)

Signed-off-by: Liwei Ma <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Superjomn <[email protected]>
Signed-off-by: Erin Ho <[email protected]>
Co-authored-by: Liwei Ma <[email protected]>
Co-authored-by: Yuan Tong <[email protected]>
Co-authored-by: Superjomn <[email protected]>
1 parent c0d5c78 commit e0e6664

File tree

17 files changed: +629 -70 lines changed

tensorrt_llm/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -84,7 +84,7 @@ def _preload_python_lib():
 from .builder import BuildConfig, Builder, BuilderConfig, build
 from .disaggregated_params import DisaggregatedParams
 from .functional import Tensor, constant
-from .llmapi import LLM, MultimodalEncoder
+from .llmapi import LLM, AsyncLLM, MultimodalEncoder
 from .llmapi.llm_args import LlmArgs, TorchLlmArgs, TrtLlmArgs
 from .logger import logger
 from .mapping import Mapping

@@ -136,6 +136,7 @@ def _preload_python_lib():
     'quantization',
     'tools',
     'LLM',
+    'AsyncLLM',
     'MultimodalEncoder',
     'LlmArgs',
     'TorchLlmArgs',
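
With this change, AsyncLLM is re-exported from the package root alongside LLM, so integration code (for example a verl rollout worker) can import it directly. A minimal import check, assuming a build that includes this commit:

    # AsyncLLM is now part of the public tensorrt_llm namespace.
    from tensorrt_llm import AsyncLLM
    from tensorrt_llm.llmapi import AsyncLLM  # same class, imported from the submodule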

tensorrt_llm/_torch/async_llm.py

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
+from typing import Any, List, Optional
+
+from ..llmapi.llm import LLM
+from ..llmapi.llm_args import RayPlacementConfig
+
+
+class AsyncLLM(LLM):
+    """AsyncLLM is a subclass of LLM that supports asynchronous setup, release and
+    resume operations that are necessary for RL or agentic scenarios.
+
+    Currently, RL APIs are only supported with Ray orchestrator.
+    """
+
+    def __init__(
+        self,
+        placement_groups: Optional[List[Any]] = None,
+        placement_bundle_indices: Optional[List[List[int]]] = None,
+        per_worker_gpu_share: Optional[float] = None,
+        *args,
+        **kwargs,
+    ):
+        kwargs["orchestrator_type"] = "ray"
+        kwargs["ray_placement_config"] = RayPlacementConfig(
+            defer_workers_init=True,
+            placement_groups=placement_groups,
+            placement_bundle_indices=placement_bundle_indices,
+            per_worker_gpu_share=per_worker_gpu_share,
+        )
+
+        # WAR: RL integration needs to use NCCL AllReduce for TP>1 due to a bug in TRTLLM's AllReduce
+        # which will cause convergence issue when using multiple rollout instances.
+        kwargs["allreduce_strategy"] = "NCCL"
+
+        if "ray_worker_extension_cls" not in kwargs:
+            kwargs["ray_worker_extension_cls"] = "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension"
+
+        super().__init__(*args, **kwargs)
+        self._async_initialized = False
+
+    async def setup_async(self):
+        """Setup the LLM asynchronously."""
+        if not self._async_initialized:
+            await self._executor.init_workers_async()
+            await self._executor.setup_engine_remote_async()
+            self._async_initialized = True
+        return self
+
+    async def release(self, tags: list[str]):
+        """Release the GPU memory used by the LLM asynchronously.
+
+        Args:
+            tags: List of memory tag strings to release (e.g., ["model", "kv_cache"]).
+        """
+        await self.collective_rpc("sleep", args=(tags,))
+
+    async def resume(self, tags: list[str]):
+        """Resume the GPU memory used by the LLM asynchronously.
+
+        Args:
+            tags: List of memory tag strings to resume (e.g., ["model", "kv_cache"]).
+        """
+        await self.collective_rpc("wakeup", args=(tags,))
+
+    async def update_weights(self, weights: dict[str, str]):
+        """Update the weights of the LLM asynchronously.
+
+
+        Args:
+            weights: Dictionary mapping device UUIDs to IPC handles for weight tensors.
+        """
+        await self.collective_rpc("update_weights", args=(weights,))
+
+    async def collective_rpc(
+        self,
+        method: str,
+        args: tuple[Any, ...] = (),
+        kwargs: Optional[dict] = None,
+        unique_reply_rank: Optional[int] = None,
+    ) -> list[Any]:
+        """Execute an asynchronous RPC call on all GPU workers. Currently, this is only supported for RayExecutor.
+
+        Args:
+            method (str): The name of the worker method to execute.
+            args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
+            kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
+            unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply.
+
+        Returns:
+            list[Any]: A list of results from each worker.
+        """
+        return await self._executor.collective_rpc_async(
+            method, args, kwargs, unique_reply_rank=unique_reply_rank
+        )
+
+    def __await__(self):
+        return self.setup_async().__await__()
+
+    def __enter__(self):
+        raise RuntimeError("Please use 'async with AsyncLLM' instead")
+
+    async def __aenter__(self):
+        await self.setup_async()
+        return super().__enter__()
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        return super().__exit__(exc_type, exc_val, exc_tb)
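
A minimal usage sketch for the class above (not part of the commit). It assumes a running Ray cluster, since AsyncLLM forces orchestrator_type="ray", and uses a placeholder model path; the tag names come from the release()/resume() docstrings:

    import asyncio
    from tensorrt_llm import AsyncLLM

    async def main():
        # "/path/to/model" is a placeholder; placement arguments keep their defaults.
        async with AsyncLLM(model="/path/to/model") as llm:  # __aenter__ awaits setup_async()
            # Offload GPU memory between rollout phases, then bring it back.
            # Both calls are thin wrappers over collective_rpc ("sleep" / "wakeup").
            await llm.release(["kv_cache"])
            await llm.resume(["kv_cache"])

    asyncio.run(main())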

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 26 additions & 19 deletions

@@ -3010,7 +3010,7 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
         new_tokens_host = state.host.new_tokens.flatten().tolist()
         sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
         finish_reasons = state.host.finish_reasons.flatten().tolist()
-        log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
+        log_probs_host_tensor = state.host.log_probs
         cum_log_probs_host = (
             state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
         )

@@ -3032,24 +3032,31 @@
         add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)
 
         # Log probs
-        for request in reqs_with_new_tokens:
-            if request.py_return_log_probs:
-                seq_slot = request.py_seq_slot
-                seq_len = sequence_lengths_host_data[seq_slot]
-                begin_log_probs_offset = request.prompt_len
-                current_token = seq_len - request.prompt_len - 1
-                log_probs = [
-                    {
-                        new_tokens_host[seq_slot]: Logprob(
-                            logprob=log_probs_host[seq_slot][0][
-                                begin_log_probs_offset + current_token
-                            ],
-                            rank=1,
-                        )
-                    }
-                ]
-                cum_log_probs = [cum_log_probs_host[seq_slot]]
-                request.py_result.append_log_probs([log_probs], cum_log_probs)
+        if log_probs_host_tensor is not None:
+            # Log probs
+            seq_slots = []
+            seq_lens = []
+            for request in reqs_with_new_tokens:
+                if request.py_return_log_probs:
+                    seq_slot = request.py_seq_slot
+                    seq_slots.append(seq_slot)
+                    seq_lens.append(sequence_lengths_host_data[seq_slot] - 1)
+
+            log_probs_host = log_probs_host_tensor[seq_slots, 0, seq_lens].tolist()
+            idx = 0
+            for request in reqs_with_new_tokens:
+                if request.py_return_log_probs:
+                    log_probs = [
+                        {
+                            new_tokens_host[seq_slot]: Logprob(
+                                logprob=log_probs_host[idx],
+                                rank=1,
+                            )
+                        }
+                    ]
+                    cum_log_probs = [cum_log_probs_host[seq_slot]]
+                    request.py_result.append_log_probs([log_probs], cum_log_probs)
+                    idx += 1
 
         for request in reqs:
             request.py_decoding_iter += 1
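
The rewritten log-probs block trades a per-request lookup into a nested Python list for a single advanced-indexing gather on the host tensor. A standalone PyTorch sketch of that pattern with made-up shapes (the real tensor is state.host.log_probs):

    import torch

    # Hypothetical shape: [num_seq_slots, beam_width, max_tokens].
    log_probs = torch.randn(8, 1, 16)
    seq_slots = [2, 5, 7]   # slot of each request that wants log probs
    seq_lens = [3, 9, 1]    # position of the newly sampled token in each slot
    # Paired index lists select one element per request in a single operation.
    selected = log_probs[seq_slots, 0, seq_lens]  # shape: [3]
    print(selected.tolist())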

tensorrt_llm/_torch/virtual_memory.py

Lines changed: 2 additions & 1 deletion

@@ -74,7 +74,8 @@ class ExecutorMemoryType(StrEnum):
     SPEC_RESOURCES = "spec_resource_manager"
     INIT_KV_CACHE = "_no_capture_init_kv_cache"
     INIT_EXTRA_RESOURCES = "_no_capture_init_extra_resources"
-    MODEL_EXTRA = "_no_capture_model_extra" # TODO: remove _no_capture after torch fix crash on torch.cuda.empty_cache()
+    # MODEL_EXTRA = "_no_capture_model_extra" # TODO: remove _no_capture after torch fix crash on torch.cuda.empty_cache()
+    MODEL_EXTRA = "model_extra"
     EXTRA_RESOURCES = "executor_extra"
     KV_CACHE = "kv_cache"
     MODEL_ENGINE_MAIN = "model"
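
For context on the tag strings: ExecutorMemoryType is a StrEnum, so its members behave like the plain tag strings used elsewhere (presumably the same tags passed to release()/resume() above). A small standalone sketch of that behavior, an assumption for illustration rather than code from this commit:

    from enum import StrEnum

    class MemoryTag(StrEnum):  # stand-in for ExecutorMemoryType
        MODEL_EXTRA = "model_extra"
        KV_CACHE = "kv_cache"
        MODEL = "model"

    # StrEnum members are str instances, so plain tag strings compare equal to them.
    requested = ["model", "kv_cache"]
    to_release = [tag for tag in MemoryTag if tag in requested]
    assert MemoryTag.KV_CACHE == "kv_cache" and len(to_release) == 2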
