647 changes: 647 additions & 0 deletions examples/llm-api/rl_integration_test_async.py

Large diffs are not rendered by default.

722 changes: 722 additions & 0 deletions examples/llm-api/rl_integration_test_async_pg.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion tensorrt_llm/__init__.py
@@ -84,7 +84,7 @@ def _preload_python_lib():
from .builder import BuildConfig, Builder, BuilderConfig, build
from .disaggregated_params import DisaggregatedParams
from .functional import Tensor, constant
from .llmapi import LLM, MultimodalEncoder
from .llmapi import LLM, AsyncLLM, MultimodalEncoder
from .llmapi.llm_args import LlmArgs, TorchLlmArgs, TrtLlmArgs
from .logger import logger
from .mapping import Mapping
@@ -136,6 +136,7 @@ def _preload_python_lib():
'quantization',
'tools',
'LLM',
'AsyncLLM',
'MultimodalEncoder',
'LlmArgs',
'TorchLlmArgs',
55 changes: 55 additions & 0 deletions tensorrt_llm/_torch/async_llm.py
@@ -0,0 +1,55 @@
from typing import Any, Optional

from .llm import LLM
from .virtual_memory import ExecutorMemoryType


class AsyncLLM(LLM):
"""AsyncLLM is a subclass of LLM that supports asynchronous setup, release and
resume operations that are necessary for RL or agentic scenarios.
"""

def __init__(self, *args, **kwargs):
# AsyncLLM currently supports only the Ray orchestrator.
kwargs["orchestrator_type"] = "ray"
super().__init__(*args, **kwargs)

async def setup_async(self):
"""Setup the LLM asynchronously."""
await self._executor.init_workers_async()

async def release_memory_async(self):
"""Release the GPU memory used by the LLM asynchronously."""
tags = [tag.value for tag in ExecutorMemoryType]
await self.collective_call_async("sleep", args=(tags,))

async def resume_memory_async(self):
"""Resume the GPU memory used by the LLM asynchronously."""
tags = [tag.value for tag in ExecutorMemoryType]
await self.collective_call_async("wakeup", args=(tags,))

async def update_weights_async(self, weights: dict[str, str]):
"""Update the weights of the LLM asynchronously."""
await self.collective_call_async("update_weights", args=(weights,))

async def collective_call_async(
self,
method: str,
args: tuple[Any, ...] = (),
kwargs: Optional[dict] = None,
unique_reply_rank: Optional[int] = None,
) -> list[Any]:
"""Execute an asynchronous RPC call on all GPU workers. Currently, this is only supported for RayExecutor.

Args:
method (str): The name of the worker method to execute.
args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply.

Returns:
list[Any]: A list of results from each worker.
"""
return await self._executor.collective_rpc_async(
method, args, kwargs, unique_reply_rank=unique_reply_rank
)
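
A minimal usage sketch for the AsyncLLM API added above, in an RL-style loop. The model and weight paths are illustrative, and the shape of the weights mapping passed to update_weights_async (name to checkpoint path, per its dict[str, str] signature) is an assumption:

import asyncio

from tensorrt_llm import AsyncLLM

async def rollout_and_train():
    llm = AsyncLLM(model="/path/to/model")  # illustrative path; forces the Ray orchestrator
    await llm.setup_async()  # initialize Ray workers without blocking the event loop
    # ... generate rollouts with llm ...
    await llm.release_memory_async()  # "sleep": free executor GPU memory for the trainer
    # ... run a training step on the same GPUs ...
    await llm.resume_memory_async()  # "wakeup": restore executor GPU memory
    await llm.update_weights_async({"model": "/tmp/updated_weights"})  # illustrative mapping

asyncio.run(rollout_and_train())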
121 changes: 96 additions & 25 deletions tensorrt_llm/executor/ray_executor.py
@@ -1,3 +1,4 @@
import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple

@@ -7,8 +8,7 @@
e.msg = """Cannot import Ray. Please install 'ray' package to use ray orchestrator"""
raise

from ray.util.placement_group import (PlacementGroup,
PlacementGroupSchedulingStrategy,
from ray.util.placement_group import (PlacementGroupSchedulingStrategy,
get_current_placement_group,
placement_group)

@@ -23,6 +23,7 @@
from .request import GenerationRequest
from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
from .rpc_proxy import RpcExecutorMixin
from .utils import has_event_loop

__all__ = [
"RayExecutor",
@@ -78,14 +79,16 @@ def __init__(self,
self.master_port = get_free_port()
self.use_rpc = ray_use_rpc()

worker_kwargs = dict(**worker_kwargs,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
self.worker_kwargs = dict(
**worker_kwargs,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)

if self.use_rpc:
self.init_rpc_executor()
worker_kwargs['rpc_addr'] = self.rpc_addr
self.create_workers(RayGPUWorker, worker_kwargs)
self.worker_kwargs['rpc_addr'] = self.rpc_addr
if not has_event_loop():
self.init_workers_sync()
self.setup_engine_remote()
self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
thread_name="ray_executor_main_loop")
@@ -107,17 +110,22 @@ def __init__(self,
self.response_sync_queue)
self.response_queue.warmup.remote()
self.response_sync_queue.warmup.remote()
self.create_workers(RayGPUWorker, worker_kwargs)
if not has_event_loop():
self.init_workers_sync()

except Exception as e:
self.shutdown()
logger.error(f"Failed to initialize RayExecutor: {e}")
raise e

def create_workers(self, worker_cls, worker_kwargs):
llm_args = worker_kwargs.get("llm_args")

# When set to be a fraction, it allows Ray to schedule
# multiple actors on a single GPU for colocate use cases.
num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))
num_gpus = (llm_args.per_worker_gpu_share if llm_args
and llm_args.per_worker_gpu_share is not None else float(
os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0")))
logger.debug(f"{num_gpus=} for each worker.")

runtime_env = ray.runtime_env.RuntimeEnv()
Expand All @@ -128,28 +136,40 @@ def create_workers(self, worker_cls, worker_kwargs):
"MASTER_PORT": str(self.master_port)
})

self.placement_group, self.bundle_indices = self._get_placement_group(
tp_size=self.tp_size)
placement_groups, self.bundle_indices = self._get_placement_group(
tp_size=self.tp_size, worker_kwargs=worker_kwargs)

self.workers = [
RayWorkerWrapper.options(
if isinstance(placement_groups, list):
self.placement_group = None
else:
self.placement_group = placement_groups

self.workers = []
for rank in range(self.world_size):
pg = placement_groups[rank] if isinstance(
placement_groups, list) else placement_groups
worker = RayWorkerWrapper.options(
num_gpus=num_gpus,
runtime_env=runtime_env, # per-actor env
runtime_env=runtime_env,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=self.placement_group,
placement_group=pg,
placement_group_bundle_index=self.bundle_indices[rank],
)).remote(worker_cls, worker_kwargs, self.world_size, rank)
for rank in range(self.world_size)
]
self.workers.append(worker)

def init_workers_sync(self):
self.create_workers(RayGPUWorker, self.worker_kwargs)
try:
ray.get([worker.__ray_ready__.remote() for worker in self.workers])
ray.get(self._get_worker_ready_futures())
except ray.exceptions.ActorDiedError as e:
if "The actor died because of an error raised in its creation task" in str(
e):
raise RuntimeError(
"RayGPUWorker died during initialization") from e
raise
raise RuntimeError("RayGPUWorker died during initialization") from e

async def init_workers_async(self):
self.create_workers(RayGPUWorker, self.worker_kwargs)
try:
await asyncio.gather(*self._get_worker_ready_futures())
except ray.exceptions.ActorDiedError as e:
raise RuntimeError("RayGPUWorker died during initialization") from e
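
# Note: __init__ calls init_workers_sync() only when has_event_loop() is
# false; under a running event loop, worker creation is deferred so the
# caller can `await init_workers_async()` instead (e.g. via
# AsyncLLM.setup_async), avoiding a blocking ray.get inside the loop.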

@unwrap_ray_errors()
def call_all_ray_workers(self, func: str, leader_only: bool,
@@ -189,6 +209,21 @@ def collective_rpc(self,
**kwargs))
return refs if non_block else ray.get(refs)

@unwrap_ray_errors()
async def collective_rpc_async(
self,
method: str,
args: tuple = (),
kwargs: Optional[dict] = None,
unique_reply_rank: Optional[int] = None) -> list[Any]:
refs = self.collective_rpc(method,
args,
kwargs,
non_block=True,
unique_reply_rank=unique_reply_rank)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, ray.get, refs)
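
# Design note: the async variant reuses collective_rpc(non_block=True) to
# obtain Ray ObjectRefs, then awaits the blocking ray.get on the default
# thread-pool executor, so the running event loop stays responsive while
# the workers execute the RPC.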

def submit(self, request: "GenerationRequest") -> "GenerationResult":
"""
Low-level API to the executor. Return a "future" GenerationResult
@@ -316,15 +351,51 @@ def shutdown(self):
logger.debug("Shutting down Ray cluster")
ray.shutdown()

def _get_placement_group(self,
tp_size: int) -> Tuple[PlacementGroup, List[int]]:
def _get_worker_ready_futures(self):
return [worker.__ray_ready__.remote() for worker in self.workers]

def _get_placement_group(
self,
tp_size: int,
worker_kwargs: Dict = None) -> Tuple[Any, List[int]]:
"""
Either use existing placement groups from the driver script (e.g., for RL framework integration),
or create a default PACK placement group where each bundle has tp_size GPUs.
- When tp_size ≤ GPUs per node, keep each TP group on a single node.
- When tp_size > GPUs per node, allow a TP group to span nodes.
- Rank 0 must be placed on the driver node.

Returns:
Tuple of (placement_group(s), bundle_indices)
- placement_group(s) can be a single PlacementGroup or a List[PlacementGroup]
- bundle_indices is always a List[int]
"""
llm_args = worker_kwargs.get("llm_args") if worker_kwargs else None

if llm_args and hasattr(
llm_args,
'placement_groups') and llm_args.placement_groups is not None:
total_workers = sum(
len(indices) for indices in llm_args.placement_bundle_indices)
if total_workers != self.world_size:
raise ValueError(
f"Total bundle indices ({total_workers}) must equal world_size ({self.world_size})"
)

logger.info(
f"Creating {self.world_size} workers with external placement groups"
)

flat_pgs = []
flat_indices = []
for pg, indices in zip(llm_args.placement_groups,
llm_args.placement_bundle_indices):
for idx in indices:
flat_pgs.append(pg)
flat_indices.append(idx)

return flat_pgs, flat_indices

bundle_indices = os.getenv("TRTLLM_RAY_BUNDLE_INDICES", None)

if bundle_indices:
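
A sketch of driving the external-placement-group path above from an RL framework's driver script. The placement_groups and placement_bundle_indices fields are the LlmArgs attributes read in _get_placement_group, and per_worker_gpu_share is the field read in create_workers; passing them as AsyncLLM constructor keyword arguments is an assumption, and all paths and sizes are illustrative:

import ray
from ray.util.placement_group import placement_group

from tensorrt_llm import AsyncLLM

ray.init()
# One PACK placement group with one GPU bundle per worker rank.
pg = placement_group([{"GPU": 1, "CPU": 1}] * 4, strategy="PACK")
ray.get(pg.ready())

llm = AsyncLLM(
    model="/path/to/model",  # illustrative
    tensor_parallel_size=4,
    placement_groups=[pg],  # one group covering all ranks
    placement_bundle_indices=[[0, 1, 2, 3]],  # one bundle index per rank; total must equal world_size
    per_worker_gpu_share=0.5,  # optional: fractional GPU so a trainer actor can colocate
)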
1 change: 0 additions & 1 deletion tensorrt_llm/executor/ray_gpu_worker.py
@@ -44,7 +44,6 @@ class RayWorkerWrapper:
def __init__(self, worker_cls, worker_kwargs, world_size, rank):
self.master_address = os.environ["MASTER_ADDR"]
self.master_port = os.environ["MASTER_PORT"]

# Ray can't pickle TensorRT logger
global logger
from tensorrt_llm.logger import logger
3 changes: 2 additions & 1 deletion tensorrt_llm/llmapi/__init__.py
@@ -2,7 +2,7 @@
from ..executor import CompletionOutput, LoRARequest, RequestError
from ..sampling_params import GuidedDecodingParams, SamplingParams
from .build_cache import BuildCacheConfig
from .llm import LLM, RequestOutput
from .llm import LLM, AsyncLLM, RequestOutput
# yapf: disable
from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
CacheTransceiverConfig, CalibConfig,
@@ -23,6 +23,7 @@

__all__ = [
'LLM',
'AsyncLLM',
'MultimodalEncoder',
'CompletionOutput',
'RequestOutput',
11 changes: 9 additions & 2 deletions tensorrt_llm/llmapi/llm.py
@@ -48,6 +48,7 @@
# TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
from .utils import (append_docstring, exception_handler, get_device_count,
logger_debug, set_api_status)
from ray.util.placement_group import PlacementGroup, placement_group


class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
@@ -189,7 +190,7 @@ def __init__(self,
self.mpi_session = self.args.mpi_session

if self.args.parallel_config.is_multi_gpu:
if get_device_count(
if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count(
) < self.args.parallel_config.world_size_per_node:
raise RuntimeError(
f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required."
@@ -225,7 +226,6 @@

self.runtime_context: Optional[_ModelRuntimeContext] = None
self.llm_build_stats = LlmBuildStats()

self._build_model()

except Exception:
@@ -1125,3 +1125,10 @@ def __init__(self,

Parameters:
""" + TORCH_LLM_DOCSTRING

class AsyncLLM(LLM):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def async_init_phase(self):
await self._executor.init_workers_async()
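
# A minimal usage sketch for this variant (assumption: constructing it under
# a running event loop defers Ray worker creation, as in RayExecutor above):
#
#     llm = AsyncLLM(model="/path/to/model")  # illustrative path
#     await llm.async_init_phase()            # finish Ray worker setup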