647 changes: 647 additions & 0 deletions examples/llm-api/rl_integration_test_async.py

Large diffs are not rendered by default.

722 changes: 722 additions & 0 deletions examples/llm-api/rl_integration_test_async_pg.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion tensorrt_llm/__init__.py
@@ -84,7 +84,7 @@ def _preload_python_lib():
from .builder import BuildConfig, Builder, BuilderConfig, build
from .disaggregated_params import DisaggregatedParams
from .functional import Tensor, constant
from .llmapi import LLM, MultimodalEncoder
from .llmapi import LLM, AsyncLLM, MultimodalEncoder
from .llmapi.llm_args import LlmArgs, TorchLlmArgs, TrtLlmArgs
from .logger import logger
from .mapping import Mapping
@@ -136,6 +136,7 @@ def _preload_python_lib():
'quantization',
'tools',
'LLM',
'AsyncLLM',
'MultimodalEncoder',
'LlmArgs',
'TorchLlmArgs',
106 changes: 81 additions & 25 deletions tensorrt_llm/executor/ray_executor.py
@@ -1,3 +1,4 @@
import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple

@@ -7,8 +8,7 @@
e.msg = """Cannot import Ray. Please install 'ray' package to use ray orchestrator"""
raise

from ray.util.placement_group import (PlacementGroup,
PlacementGroupSchedulingStrategy,
from ray.util.placement_group import (PlacementGroupSchedulingStrategy,
get_current_placement_group,
placement_group)

@@ -23,6 +23,7 @@
from .request import GenerationRequest
from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
from .rpc_proxy import RpcExecutorMixin
from .utils import has_event_loop

__all__ = [
"RayExecutor",
@@ -78,14 +79,16 @@ def __init__(self,
self.master_port = get_free_port()
self.use_rpc = ray_use_rpc()

worker_kwargs = dict(**worker_kwargs,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
self.worker_kwargs = dict(
**worker_kwargs,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)

if self.use_rpc:
self.init_rpc_executor()
worker_kwargs['rpc_addr'] = self.rpc_addr
self.create_workers(RayGPUWorker, worker_kwargs)
self.worker_kwargs['rpc_addr'] = self.rpc_addr
if not has_event_loop():
self.init_workers_sync()
self.setup_engine_remote()
self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
thread_name="ray_executor_main_loop")
@@ -107,17 +110,22 @@ def __init__(self,
self.response_sync_queue)
self.response_queue.warmup.remote()
self.response_sync_queue.warmup.remote()
self.create_workers(RayGPUWorker, worker_kwargs)
if not has_event_loop():
self.init_workers_sync()

except Exception as e:
self.shutdown()
logger.error(f"Failed to initialize RayExecutor: {e}")
raise e

def create_workers(self, worker_cls, worker_kwargs):
llm_args = worker_kwargs.get("llm_args")

# When set to be a fraction, it allows Ray to schedule
# multiple actors on a single GPU for colocate use cases.
num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))
num_gpus = (llm_args.per_worker_gpu_share if llm_args
and llm_args.per_worker_gpu_share is not None else float(
os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0")))
logger.debug(f"{num_gpus=} for each worker.")

runtime_env = ray.runtime_env.RuntimeEnv()
@@ -128,28 +136,40 @@ def create_workers(self, worker_cls, worker_kwargs):
"MASTER_PORT": str(self.master_port)
})

self.placement_group, self.bundle_indices = self._get_placement_group(
tp_size=self.tp_size)
placement_groups, self.bundle_indices = self._get_placement_group(
tp_size=self.tp_size, worker_kwargs=worker_kwargs)

if isinstance(placement_groups, list):
self.placement_group = None
else:
self.placement_group = placement_groups

self.workers = [
RayWorkerWrapper.options(
self.workers = []
for rank in range(self.world_size):
pg = placement_groups[rank] if isinstance(
placement_groups, list) else placement_groups
worker = RayWorkerWrapper.options(
num_gpus=num_gpus,
runtime_env=runtime_env, # per-actor env
runtime_env=runtime_env,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=self.placement_group,
placement_group=pg,
placement_group_bundle_index=self.bundle_indices[rank],
)).remote(worker_cls, worker_kwargs, self.world_size, rank)
for rank in range(self.world_size)
]
self.workers.append(worker)

def init_workers_sync(self):
self.create_workers(RayGPUWorker, self.worker_kwargs)
try:
ray.get(self._get_worker_ready_futures())
except ray.exceptions.ActorDiedError as e:
raise RuntimeError("RayGPUWorker died during initialization") from e

async def init_workers_async(self):
self.create_workers(RayGPUWorker, self.worker_kwargs)
try:
ray.get([worker.__ray_ready__.remote() for worker in self.workers])
await asyncio.gather(*self._get_worker_ready_futures())
except ray.exceptions.ActorDiedError as e:
if "The actor died because of an error raised in its creation task" in str(
e):
raise RuntimeError(
"RayGPUWorker died during initialization") from e
raise
raise RuntimeError("RayGPUWorker died during initialization") from e

@unwrap_ray_errors()
def call_all_ray_workers(self, func: str, leader_only: bool,
@@ -316,15 +336,51 @@ def shutdown(self):
logger.debug("Shutting down Ray cluster")
ray.shutdown()

def _get_placement_group(self,
tp_size: int) -> Tuple[PlacementGroup, List[int]]:
def _get_worker_ready_futures(self):
return [worker.__ray_ready__.remote() for worker in self.workers]

def _get_placement_group(
self,
tp_size: int,
worker_kwargs: Dict = None) -> Tuple[Any, List[int]]:
"""
Either use existing placement groups passed in from the driver script (e.g., for RL framework integration),
or create a default PACK placement group where each bundle has tp_size GPUs.
- When tp_size ≤ GPUs per node, keep one TP group per node.
- When tp_size > GPUs per node, allow a TP group to span nodes.
- Rank 0 must be placed on the driver node.

Returns:
Tuple of (placement_group(s), bundle_indices)
- placement_group(s) can be a single PlacementGroup or a List[PlacementGroup]
- bundle_indices is always a List[int]
"""
llm_args = worker_kwargs.get("llm_args") if worker_kwargs else None

if llm_args and hasattr(
llm_args,
'placement_groups') and llm_args.placement_groups is not None:
total_workers = sum(
len(indices) for indices in llm_args.placement_bundle_indices)
if total_workers != self.world_size:
raise ValueError(
f"Total bundle indices ({total_workers}) must equal world_size ({self.world_size})"
)

logger.info(
f"Creating {self.world_size} workers with external placement groups"
)

flat_pgs = []
flat_indices = []
for pg, indices in zip(llm_args.placement_groups,
llm_args.placement_bundle_indices):
for idx in indices:
flat_pgs.append(pg)
flat_indices.append(idx)

return flat_pgs, flat_indices

bundle_indices = os.getenv("TRTLLM_RAY_BUNDLE_INDICES", None)

if bundle_indices:
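The colocation comment in create_workers above relies on Ray's fractional GPU scheduling combined with placement-group bundles. Below is a minimal, self-contained sketch of that mechanism with a toy actor; the 0.25 share and the Worker class are illustrative assumptions, not part of this PR (the executor derives its fraction from per_worker_gpu_share or TRTLLM_RAY_PER_WORKER_GPUS).

```python
import os

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()

# One bundle with a single GPU; four actors will share it.
pg = placement_group([{"GPU": 1, "CPU": 4}], strategy="PACK")
ray.get(pg.ready())


@ray.remote
class Worker:
    def visible_devices(self):
        # Ray sets CUDA_VISIBLE_DEVICES based on the GPU share it assigned.
        return os.environ.get("CUDA_VISIBLE_DEVICES")


workers = [
    Worker.options(
        num_gpus=0.25,  # fractional share, analogous to per_worker_gpu_share
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_bundle_index=0,
        ),
    ).remote()
    for _ in range(4)
]

# All four actors report the same device: they are colocated on one GPU.
print(ray.get([w.visible_devices.remote() for w in workers]))
```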
1 change: 0 additions & 1 deletion tensorrt_llm/executor/ray_gpu_worker.py
@@ -44,7 +44,6 @@ class RayWorkerWrapper:
def __init__(self, worker_cls, worker_kwargs, world_size, rank):
self.master_address = os.environ["MASTER_ADDR"]
self.master_port = os.environ["MASTER_PORT"]

# Ray can't pickle TensorRT logger
global logger
from tensorrt_llm.logger import logger
3 changes: 2 additions & 1 deletion tensorrt_llm/llmapi/__init__.py
@@ -2,7 +2,7 @@
from ..executor import CompletionOutput, LoRARequest, RequestError
from ..sampling_params import GuidedDecodingParams, SamplingParams
from .build_cache import BuildCacheConfig
from .llm import LLM, RequestOutput
from .llm import LLM, AsyncLLM, RequestOutput
# yapf: disable
from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
CacheTransceiverConfig, CalibConfig,
@@ -23,6 +23,7 @@

__all__ = [
'LLM',
'AsyncLLM',
'MultimodalEncoder',
'CompletionOutput',
'RequestOutput',
11 changes: 9 additions & 2 deletions tensorrt_llm/llmapi/llm.py
@@ -48,6 +48,7 @@
# TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
from .utils import (append_docstring, exception_handler, get_device_count,
logger_debug, set_api_status)
from ray.util.placement_group import PlacementGroup, placement_group


class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
@@ -189,7 +190,7 @@ def __init__(self,
self.mpi_session = self.args.mpi_session

if self.args.parallel_config.is_multi_gpu:
if get_device_count(
if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count(
) < self.args.parallel_config.world_size_per_node:
raise RuntimeError(
f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required."
@@ -225,7 +226,6 @@ def __init__(self,

self.runtime_context: Optional[_ModelRuntimeContext] = None
self.llm_build_stats = LlmBuildStats()

self._build_model()

except Exception:
@@ -1125,3 +1125,10 @@ def __init__(self,

Parameters:
""" + TORCH_LLM_DOCSTRING

class AsyncLLM(LLM):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def async_init_phase(self):
Review comment (author): Notes from syncing w/ Liwei:

VeRL needs async init for TRTLLM's LLM(), but Python requires __init__ to be synchronous, so Liwei separated the async part out here. See the usage sketch after this hunk.

await self._executor.init_workers_async()
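A hedged usage sketch of the split initialization described in the review comment above: constructing AsyncLLM inside a running event loop defers Ray worker creation, and async_init_phase() completes it without blocking the loop. The model path, parallel settings, and the generation/aresult() calls are illustrative assumptions rather than parts of this diff.

```python
import asyncio

from tensorrt_llm import AsyncLLM


async def main():
    # __init__ stays synchronous; because an event loop is already running,
    # RayExecutor skips init_workers_sync() and defers worker creation.
    llm = AsyncLLM(
        model="/path/to/model",      # placeholder path
        orchestrator_type="ray",
        tensor_parallel_size=2,
    )
    # The async half of initialization: creates the Ray workers and awaits
    # their readiness via init_workers_async().
    await llm.async_init_phase()

    # Illustrative generation call using the existing async API.
    result = llm.generate_async("Hello, my name is")
    print(await result.aresult())


asyncio.run(main())
```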
70 changes: 68 additions & 2 deletions tensorrt_llm/llmapi/llm_args.py
@@ -8,8 +8,9 @@
from dataclasses import dataclass
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple,
Type, TypeAlias, TypeVar, Union, get_args, get_origin)
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
Set, Tuple, Type, TypeAlias, TypeVar, Union, get_args,
get_origin)

import torch
import yaml
@@ -19,6 +20,11 @@
from strenum import StrEnum
from transformers import PreTrainedTokenizerBase

try:
from ray.util.placement_group import PlacementGroup
except ImportError:
PlacementGroup = None

from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)

@@ -1926,6 +1932,8 @@ def validate_dtype(cls, v, info):
@field_validator("gpus_per_node", mode='before')
@classmethod
def validate_gpus_per_node(cls, v, info):
if os.getenv("RAY_LOCAL_WORLD_SIZE") is not None:
Review comment (author): Note: according to Liwei, this check is likely obsolete.

return info.data.get("tensor_parallel_size")
if v is None:
logger.warning(
f"Using default gpus_per_node: {torch.cuda.device_count()}")
@@ -2701,6 +2709,26 @@ class TorchLlmArgs(BaseLlmArgs):
"Allows users to extend the functions of the RayGPUWorker class.",
status="prototype")

# Ray placement group config. Naming TBD.
placement_groups: Optional[List[Any]] = Field(
default=None,
description="List of Ray placement groups, one per node. "
"Each element must be a ray.util.placement_group.PlacementGroup instance.",
exclude_from_json=True,
status="prototype")

placement_bundle_indices: Optional[List[List[int]]] = Field(
default=None,
description="List of bundle indices for each placement group. "
"Outer list corresponds to placement_groups, inner list contains bundle indices for that group. ",
status="prototype")

per_worker_gpu_share: Optional[float] = Field(
default=None,
description="GPU fraction per worker for colocation scenarios. "
"Example: 0.1 means 10 actors can share one GPU. Defaults to 1.0 (one actor per GPU).",
status="prototype")

enable_sleep: bool = Field(
default=False,
description=
@@ -2945,6 +2973,44 @@ def validate_ray_worker_extension_cls(self) -> 'TorchLlmArgs':
)
return self

@model_validator(mode='after')
def validate_ray_placement_config(self) -> 'TorchLlmArgs':
has_pgs = self.placement_groups is not None
has_indices = self.placement_bundle_indices is not None

if (has_pgs or has_indices) and self.orchestrator_type != "ray":
raise ValueError(
"placement_groups is only supported with orchestrator_type='ray'"
)

if has_pgs != has_indices:
raise ValueError(
"placement_groups and placement_bundle_indices must be provided together"
)

if has_pgs:
if len(self.placement_groups) != len(self.placement_bundle_indices):
raise ValueError(
f"placement_groups length ({len(self.placement_groups)}) must equal "
f"placement_bundle_indices length ({len(self.placement_bundle_indices)})"
)

if self.per_worker_gpu_share is not None:
if not (0 < self.per_worker_gpu_share <= 1.0):
raise ValueError(
f"per_worker_gpu_share must be between 0 and 1.0, "
f"got {self.per_worker_gpu_share}")

if has_pgs:
if PlacementGroup is not None:
for i, pg in enumerate(self.placement_groups):
if not isinstance(pg, PlacementGroup):
raise TypeError(
f"placement_groups[{i}] must be a Ray PlacementGroup, "
f"got {type(pg).__name__}")

return self

def get_executor_config(
self,
_hf_model_dir: Optional[Path] = None,
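A hedged configuration sketch for the three prototype fields above, written to satisfy validate_ray_placement_config: both placement fields set together, matching list lengths, total bundle indices equal to world size, per_worker_gpu_share in (0, 1], and orchestrator_type="ray". The model path and resource shapes are placeholders.

```python
import ray
from ray.util.placement_group import placement_group

from tensorrt_llm import LLM

ray.init()

# An externally owned placement group (e.g. created by the RL framework),
# with one 1-GPU bundle per TP rank.
pg = placement_group([{"GPU": 1, "CPU": 2}] * 4, strategy="PACK")
ray.get(pg.ready())

llm = LLM(
    model="/path/to/model",              # placeholder
    orchestrator_type="ray",
    tensor_parallel_size=4,
    # Total number of bundle indices must equal world_size (4 here).
    placement_groups=[pg],
    placement_bundle_indices=[[0, 1, 2, 3]],
    # Fractional GPU share for colocation: 0.5 lets an RL trainer actor
    # share each GPU with the inference worker.
    per_worker_gpu_share=0.5,
)
```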
4 changes: 3 additions & 1 deletion tensorrt_llm/llmapi/rlhf_utils.py
@@ -3,6 +3,8 @@
from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.logger import logger
import pickle
import base64


class WorkerExtension:
Expand Down Expand Up @@ -52,7 +54,7 @@ def update_weights(self, ipc_handles: dict):
raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")

weights = {}
all_handles = ipc_handles[device_uuid]
all_handles = pickle.loads(base64.b64decode(ipc_handles[device_uuid]))

for param_name, tensor_handle in all_handles:
func, args = tensor_handle
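The decode added above implies a matching encode on the trainer side: each device UUID maps to a base64-encoded pickle of (param_name, tensor_handle) pairs. A hedged sender-side sketch follows, assuming the handles come from torch's CUDA IPC reduction, which yields exactly the (func, args) pairs that update_weights unpacks; pack_ipc_handles and the usage lines are illustrative, not APIs introduced by this PR.

```python
import base64
import pickle

from torch.multiprocessing.reductions import reduce_tensor


def pack_ipc_handles(named_params, device_uuid: str) -> dict:
    # reduce_tensor(t) returns (rebuild_func, args) for a CUDA tensor, which is
    # the `func, args = tensor_handle` pair unpacked by update_weights above.
    all_handles = [(name, reduce_tensor(param.detach()))
                   for name, param in named_params]
    payload = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8")
    return {device_uuid: payload}


# Usage (illustrative): build the payload on the trainer for its local GPU and
# hand it to the worker-side update_weights(ipc_handles) shown above.
# ipc_handles = pack_ipc_handles(trainer_model.named_parameters(), device_uuid)
```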