Merged
3 changes: 2 additions & 1 deletion tensorrt_llm/__init__.py
@@ -84,7 +84,7 @@ def _preload_python_lib():
from .builder import BuildConfig, Builder, BuilderConfig, build
from .disaggregated_params import DisaggregatedParams
from .functional import Tensor, constant
from .llmapi import LLM, MultimodalEncoder
from .llmapi import LLM, AsyncLLM, MultimodalEncoder
from .llmapi.llm_args import LlmArgs, TorchLlmArgs, TrtLlmArgs
from .logger import logger
from .mapping import Mapping
@@ -136,6 +136,7 @@ def _preload_python_lib():
'quantization',
'tools',
'LLM',
'AsyncLLM',
'MultimodalEncoder',
'LlmArgs',
'TorchLlmArgs',
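Because AsyncLLM is now re-exported from the package root and listed in __all__, downstream code can import it exactly like LLM. A minimal sketch of the new import surface, using only the names added in this file:

from tensorrt_llm import LLM, AsyncLLM

# AsyncLLM subclasses LLM (see tensorrt_llm/_torch/async_llm.py below) and forces
# the Ray orchestrator, so it is only usable where Ray is available.
print(issubclass(AsyncLLM, LLM))  # True
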
106 changes: 106 additions & 0 deletions tensorrt_llm/_torch/async_llm.py
@@ -0,0 +1,106 @@
from typing import Any, List, Optional

from ..llmapi.llm import LLM
from ..llmapi.llm_args import RayPlacementConfig


class AsyncLLM(LLM):
"""AsyncLLM is a subclass of LLM that supports asynchronous setup, release and
resume operations that are necessary for RL or agentic scenarios.

Currently, RL APIs are only supported with Ray orchestrator.
"""

def __init__(
self,
placement_groups: Optional[List[Any]] = None,
placement_bundle_indices: Optional[List[List[int]]] = None,
per_worker_gpu_share: Optional[float] = None,
*args,
**kwargs,
):
kwargs["orchestrator_type"] = "ray"
kwargs["ray_placement_config"] = RayPlacementConfig(
defer_workers_init=True,
placement_groups=placement_groups,
placement_bundle_indices=placement_bundle_indices,
per_worker_gpu_share=per_worker_gpu_share,
)

        # WAR: RL integration must use NCCL AllReduce for TP > 1 due to a bug in TRTLLM's
        # AllReduce that causes convergence issues when multiple rollout instances are used.
kwargs["allreduce_strategy"] = "NCCL"

if "ray_worker_extension_cls" not in kwargs:
kwargs["ray_worker_extension_cls"] = "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension"

super().__init__(*args, **kwargs)
self._async_initialized = False

async def setup_async(self):
"""Setup the LLM asynchronously."""
if not self._async_initialized:
await self._executor.init_workers_async()
await self._executor.setup_engine_remote_async()
self._async_initialized = True
return self

async def release(self, tags: list[str]):
"""Release the GPU memory used by the LLM asynchronously.

Args:
tags: List of memory tag strings to release (e.g., ["model", "kv_cache"]).
"""
await self.collective_rpc("sleep", args=(tags,))

async def resume(self, tags: list[str]):
"""Resume the GPU memory used by the LLM asynchronously.

Args:
tags: List of memory tag strings to resume (e.g., ["model", "kv_cache"]).
"""
await self.collective_rpc("wakeup", args=(tags,))

async def update_weights(self, weights: dict[str, str]):
"""Update the weights of the LLM asynchronously.


Args:
weights: Dictionary mapping device UUIDs to IPC handles for weight tensors.
"""
await self.collective_rpc("update_weights", args=(weights,))

async def collective_rpc(
self,
method: str,
args: tuple[Any, ...] = (),
kwargs: Optional[dict] = None,
unique_reply_rank: Optional[int] = None,
) -> list[Any]:
"""Execute an asynchronous RPC call on all GPU workers. Currently, this is only supported for RayExecutor.

Args:
method (str): The name of the worker method to execute.
args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply.

Returns:
list[Any]: A list of results from each worker.
"""
return await self._executor.collective_rpc_async(
method, args, kwargs, unique_reply_rank=unique_reply_rank
)

def __await__(self):
return self.setup_async().__await__()

def __enter__(self):
raise RuntimeError("Please use 'async with AsyncLLM' instead")

async def __aenter__(self):
await self.setup_async()
return super().__enter__()

async def __aexit__(self, exc_type, exc_val, exc_tb):
return super().__exit__(exc_type, exc_val, exc_tb)
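
Taken together, AsyncLLM is meant to be driven from an async training loop: construct it (optionally passing Ray placement groups so rollout workers can be co-located with a trainer), initialize it with async with (or by awaiting the instance, via __await__), then alternate release/resume and update_weights between rollout phases. The sketch below is illustrative only: the model path, the tag strings, and the empty weights dictionary are placeholders, the "report_device_id" RPC target is a hypothetical worker method, and a running Ray cluster is assumed.

import asyncio

from tensorrt_llm import AsyncLLM


async def main():
    # Placement arguments (placement_groups, placement_bundle_indices,
    # per_worker_gpu_share) are optional and only needed when sharing GPUs with a
    # trainer; this minimal construction relies on the defaults.
    llm = AsyncLLM(model="/path/to/model")  # placeholder model path

    # __aenter__ awaits setup_async(), which initializes the Ray workers and the engine.
    async with llm:
        # ... run rollouts, e.g. via llm.generate_async(...) ...

        # Offload model weights and KV cache between rollout phases.
        # The tag strings follow the examples in the docstrings above.
        await llm.release(["kv_cache", "model"])

        # ... the trainer uses the freed GPU memory to update the policy ...

        await llm.resume(["model", "kv_cache"])

        # weights maps device UUIDs to IPC handles exported by the trainer;
        # producing those handles is outside the scope of this diff.
        weights = {}
        await llm.update_weights(weights)

        # Arbitrary worker-side methods can be invoked directly; the method name
        # here is hypothetical and depends on the configured worker extension.
        results = await llm.collective_rpc("report_device_id")
        print(results)


if __name__ == "__main__":
    asyncio.run(main())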
45 changes: 26 additions & 19 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -2897,7 +2897,7 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
new_tokens_host = state.host.new_tokens.flatten().tolist()
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
finish_reasons = state.host.finish_reasons.flatten().tolist()
log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
log_probs_host_tensor = state.host.log_probs
cum_log_probs_host = (
state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
)
@@ -2919,24 +2919,31 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)

# Log probs
for request in reqs_with_new_tokens:
if request.py_return_log_probs:
seq_slot = request.py_seq_slot
seq_len = sequence_lengths_host_data[seq_slot]
begin_log_probs_offset = request.prompt_len
current_token = seq_len - request.prompt_len - 1
log_probs = [
{
new_tokens_host[seq_slot]: Logprob(
logprob=log_probs_host[seq_slot][0][
begin_log_probs_offset + current_token
],
rank=1,
)
}
]
cum_log_probs = [cum_log_probs_host[seq_slot]]
request.py_result.append_log_probs([log_probs], cum_log_probs)
if log_probs_host_tensor is not None:
            # Gather the last-token log prob for every request that asked for one,
            # using a single batched index into the host tensor.
            seq_slots = []
            seq_lens = []
            for request in reqs_with_new_tokens:
                if request.py_return_log_probs:
                    seq_slot = request.py_seq_slot
                    seq_slots.append(seq_slot)
                    # Last generated token position: prompt_len + (seq_len - prompt_len - 1) == seq_len - 1.
                    seq_lens.append(sequence_lengths_host_data[seq_slot] - 1)

            log_probs_host = log_probs_host_tensor[seq_slots, 0, seq_lens].tolist()
            idx = 0
            for request in reqs_with_new_tokens:
                if request.py_return_log_probs:
                    seq_slot = request.py_seq_slot
                    log_probs = [
                        {
                            new_tokens_host[seq_slot]: Logprob(
                                logprob=log_probs_host[idx],
                                rank=1,
                            )
                        }
                    ]
                    cum_log_probs = [cum_log_probs_host[seq_slot]]
                    request.py_result.append_log_probs([log_probs], cum_log_probs)
                    idx += 1

for request in reqs:
request.py_decoding_iter += 1
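
The rewritten block above replaces the per-request lookups into a fully materialized log_probs Python list with one batched advanced-indexing gather on the host tensor followed by a single .tolist() call. Note that the old index, begin_log_probs_offset + current_token, reduces to prompt_len + (seq_len - prompt_len - 1) = seq_len - 1, which is exactly the position the new code records per request. A standalone sketch of the gather pattern, with made-up shapes and values:

import torch

# Stand-in for state.host.log_probs: [max_seq_slots, beam_width=1, max_seq_len].
log_probs_host_tensor = torch.arange(24, dtype=torch.float32).reshape(4, 1, 6)

# One (seq_slot, last-token position) pair per request that asked for log probs.
seq_slots = [2, 0, 3]
seq_lens = [4, 1, 5]  # sequence_length - 1 for each request

# Advanced indexing gathers all requested scalars in one shot; .tolist() then
# crosses the tensor-to-Python boundary once instead of once per request.
gathered = log_probs_host_tensor[seq_slots, 0, seq_lens].tolist()
assert gathered == [16.0, 1.0, 23.0]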
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/virtual_memory.py
@@ -74,7 +74,8 @@ class ExecutorMemoryType(StrEnum):
SPEC_RESOURCES = "spec_resource_manager"
INIT_KV_CACHE = "_no_capture_init_kv_cache"
INIT_EXTRA_RESOURCES = "_no_capture_init_extra_resources"
MODEL_EXTRA = "_no_capture_model_extra" # TODO: remove _no_capture after torch fix crash on torch.cuda.empty_cache()
# MODEL_EXTRA = "_no_capture_model_extra" # TODO: remove _no_capture after torch fix crash on torch.cuda.empty_cache()
MODEL_EXTRA = "model_extra"
EXTRA_RESOURCES = "executor_extra"
KV_CACHE = "kv_cache"
MODEL_ENGINE_MAIN = "model"
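
Since ExecutorMemoryType is a StrEnum, each member is itself the tag string used to group allocations, and the rename drops the _no_capture prefix so that model-extra allocations are tracked under a plain "model_extra" tag. The values "model" and "kv_cache" match the example tags in the AsyncLLM.release/resume docstrings, which suggests (though this diff does not show the wiring) that these are the strings forwarded to the sleep/wakeup RPCs. A small sketch of the StrEnum behaviour, assuming Python 3.11+ and reproducing only members visible in this hunk:

from enum import StrEnum


class ExecutorMemoryType(StrEnum):
    MODEL_EXTRA = "model_extra"
    KV_CACHE = "kv_cache"
    MODEL_ENGINE_MAIN = "model"


# StrEnum members are real strings, so they compare equal to the plain tag strings
# that memory release/resume calls accept.
assert ExecutorMemoryType.KV_CACHE == "kv_cache"
tags = [str(ExecutorMemoryType.MODEL_ENGINE_MAIN), str(ExecutorMemoryType.KV_CACHE)]
assert tags == ["model", "kv_cache"]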