diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index 978cf0796f1..cea56431b77 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -84,7 +84,7 @@ def _preload_python_lib(): from .builder import BuildConfig, Builder, BuilderConfig, build from .disaggregated_params import DisaggregatedParams from .functional import Tensor, constant -from .llmapi import LLM, MultimodalEncoder +from .llmapi import LLM, AsyncLLM, MultimodalEncoder from .llmapi.llm_args import LlmArgs, TorchLlmArgs, TrtLlmArgs from .logger import logger from .mapping import Mapping @@ -136,6 +136,7 @@ def _preload_python_lib(): 'quantization', 'tools', 'LLM', + 'AsyncLLM', 'MultimodalEncoder', 'LlmArgs', 'TorchLlmArgs', diff --git a/tensorrt_llm/_torch/async_llm.py b/tensorrt_llm/_torch/async_llm.py new file mode 100644 index 00000000000..76c33220daf --- /dev/null +++ b/tensorrt_llm/_torch/async_llm.py @@ -0,0 +1,106 @@ +from typing import Any, List, Optional + +from ..llmapi.llm import LLM +from ..llmapi.llm_args import RayPlacementConfig + + +class AsyncLLM(LLM): + """AsyncLLM is a subclass of LLM that supports asynchronous setup, release and + resume operations that are necessary for RL or agentic scenarios. + + Currently, RL APIs are only supported with Ray orchestrator. + """ + + def __init__( + self, + placement_groups: Optional[List[Any]] = None, + placement_bundle_indices: Optional[List[List[int]]] = None, + per_worker_gpu_share: Optional[float] = None, + *args, + **kwargs, + ): + kwargs["orchestrator_type"] = "ray" + kwargs["ray_placement_config"] = RayPlacementConfig( + defer_workers_init=True, + placement_groups=placement_groups, + placement_bundle_indices=placement_bundle_indices, + per_worker_gpu_share=per_worker_gpu_share, + ) + + # WAR: RL integration needs to use NCCL AllReduce for TP>1 due to a bug in TRTLLM's AllReduce + # which will cause convergence issue when using multiple rollout instances. + kwargs["allreduce_strategy"] = "NCCL" + + if "ray_worker_extension_cls" not in kwargs: + kwargs["ray_worker_extension_cls"] = "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension" + + super().__init__(*args, **kwargs) + self._async_initialized = False + + async def setup_async(self): + """Setup the LLM asynchronously.""" + if not self._async_initialized: + await self._executor.init_workers_async() + await self._executor.setup_engine_remote_async() + self._async_initialized = True + return self + + async def release(self, tags: list[str]): + """Release the GPU memory used by the LLM asynchronously. + + Args: + tags: List of memory tag strings to release (e.g., ["model", "kv_cache"]). + """ + await self.collective_rpc("sleep", args=(tags,)) + + async def resume(self, tags: list[str]): + """Resume the GPU memory used by the LLM asynchronously. + + Args: + tags: List of memory tag strings to resume (e.g., ["model", "kv_cache"]). + """ + await self.collective_rpc("wakeup", args=(tags,)) + + async def update_weights(self, weights: dict[str, str]): + """Update the weights of the LLM asynchronously. + + + Args: + weights: Dictionary mapping device UUIDs to IPC handles for weight tensors. + """ + await self.collective_rpc("update_weights", args=(weights,)) + + async def collective_rpc( + self, + method: str, + args: tuple[Any, ...] = (), + kwargs: Optional[dict] = None, + unique_reply_rank: Optional[int] = None, + ) -> list[Any]: + """Execute an asynchronous RPC call on all GPU workers. Currently, this is only supported for RayExecutor. 
+
+        Args:
+            method (str): The name of the worker method to execute.
+            args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
+            kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
+            unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply.
+
+        Returns:
+            list[Any]: A list of results from each worker.
+        """
+        return await self._executor.collective_rpc_async(
+            method, args, kwargs, unique_reply_rank=unique_reply_rank
+        )
+
+    def __await__(self):
+        return self.setup_async().__await__()
+
+    def __enter__(self):
+        raise RuntimeError("Please use 'async with AsyncLLM' instead")
+
+    async def __aenter__(self):
+        await self.setup_async()
+        return super().__enter__()
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        return super().__exit__(exc_type, exc_val, exc_tb)
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 09d69bb126a..94cda079b0b 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -2897,7 +2897,7 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
         new_tokens_host = state.host.new_tokens.flatten().tolist()
         sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
         finish_reasons = state.host.finish_reasons.flatten().tolist()
-        log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
+        log_probs_host_tensor = state.host.log_probs
         cum_log_probs_host = (
             state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
         )
@@ -2919,24 +2919,31 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
         add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)
 
         # Log probs
-        for request in reqs_with_new_tokens:
-            if request.py_return_log_probs:
-                seq_slot = request.py_seq_slot
-                seq_len = sequence_lengths_host_data[seq_slot]
-                begin_log_probs_offset = request.prompt_len
-                current_token = seq_len - request.prompt_len - 1
-                log_probs = [
-                    {
-                        new_tokens_host[seq_slot]: Logprob(
-                            logprob=log_probs_host[seq_slot][0][
-                                begin_log_probs_offset + current_token
-                            ],
-                            rank=1,
-                        )
-                    }
-                ]
-                cum_log_probs = [cum_log_probs_host[seq_slot]]
-                request.py_result.append_log_probs([log_probs], cum_log_probs)
+        if log_probs_host_tensor is not None:
+            seq_slots = []
+            seq_lens = []
+            for request in reqs_with_new_tokens:
+                if request.py_return_log_probs:
+                    seq_slot = request.py_seq_slot
+                    seq_slots.append(seq_slot)
+                    seq_lens.append(sequence_lengths_host_data[seq_slot] - 1)
+
+            log_probs_host = log_probs_host_tensor[seq_slots, 0, seq_lens].tolist()
+            idx = 0
+            for request in reqs_with_new_tokens:
+                if request.py_return_log_probs:
+                    seq_slot = request.py_seq_slot
+                    log_probs = [
+                        {
+                            new_tokens_host[seq_slot]: Logprob(
+                                logprob=log_probs_host[idx],
+                                rank=1,
+                            )
+                        }
+                    ]
+                    cum_log_probs = [cum_log_probs_host[seq_slot]]
+                    request.py_result.append_log_probs([log_probs], cum_log_probs)
+                    idx += 1
 
         for request in reqs:
             request.py_decoding_iter += 1
diff --git a/tensorrt_llm/_torch/virtual_memory.py b/tensorrt_llm/_torch/virtual_memory.py
index 3702d732539..7efdd60c35b 100644
--- a/tensorrt_llm/_torch/virtual_memory.py
+++ b/tensorrt_llm/_torch/virtual_memory.py
@@ -74,7 +74,8 @@ class ExecutorMemoryType(StrEnum):
     SPEC_RESOURCES = "spec_resource_manager"
    INIT_KV_CACHE = "_no_capture_init_kv_cache"
     INIT_EXTRA_RESOURCES = "_no_capture_init_extra_resources"
-    MODEL_EXTRA = 
"_no_capture_model_extra" # TODO: remove _no_capture after torch fix crash on torch.cuda.empty_cache() + # MODEL_EXTRA = "_no_capture_model_extra" # TODO: remove _no_capture after torch fix crash on torch.cuda.empty_cache() + MODEL_EXTRA = "model_extra" EXTRA_RESOURCES = "executor_extra" KV_CACHE = "kv_cache" MODEL_ENGINE_MAIN = "model" diff --git a/tensorrt_llm/executor/ray_executor.py b/tensorrt_llm/executor/ray_executor.py index 579aac0a715..e03f524bea4 100644 --- a/tensorrt_llm/executor/ray_executor.py +++ b/tensorrt_llm/executor/ray_executor.py @@ -1,3 +1,4 @@ +import asyncio import os from typing import Any, Dict, List, Optional, Tuple @@ -7,8 +8,7 @@ e.msg = """Cannot import Ray. Please install 'ray' package to use ray orchestrator""" raise -from ray.util.placement_group import (PlacementGroup, - PlacementGroupSchedulingStrategy, +from ray.util.placement_group import (PlacementGroupSchedulingStrategy, get_current_placement_group, placement_group) @@ -23,6 +23,7 @@ from .request import GenerationRequest from .result import GenerationResult from .rpc_proxy_mixin import RpcExecutorMixin +from .utils import has_event_loop __all__ = [ "RayExecutor", @@ -77,19 +78,30 @@ def __init__(self, self.master_address = ray.util.get_node_ip_address() self.master_port = get_free_port() - worker_kwargs = dict(**worker_kwargs, - postproc_worker_config=postproc_worker_config, - is_llm_executor=is_llm_executor) + self.worker_kwargs = dict( + **worker_kwargs, + postproc_worker_config=postproc_worker_config, + is_llm_executor=is_llm_executor) self.init_rpc_executor() # Inject the generated HMAC key into worker_kwargs for workers - worker_kwargs['hmac_key'] = self.hmac_key - worker_kwargs['rpc_addr'] = self.rpc_addr - self.create_workers(RayGPUWorker, worker_kwargs) - self.setup_engine_remote() - self.setup_mainloop(tasks=[self._fetch_responses_loop_async], - thread_name="ray_executor_main_loop") - logger.info(f"Connecting to RPC server at {self.rpc_addr}") + self.worker_kwargs['hmac_key'] = self.hmac_key + self.worker_kwargs['rpc_addr'] = self.rpc_addr + + placement_config = getattr(self.worker_kwargs['llm_args'], + 'ray_placement_config', None) + defer_workers_init = placement_config.defer_workers_init if placement_config else False + + if defer_workers_init: + self.workers = [ + ] # Placeholder, will be initialized in setup_async + self._mainloop_started = False # DO NOT start mainloop until after setup_engine_remote_async is called + else: + if not has_event_loop(): + self.init_workers_sync() + self.setup_engine_remote() + self.setup_mainloop(tasks=[self._fetch_responses_loop_async], + thread_name="ray_executor_main_loop") except Exception as e: self.shutdown() @@ -97,9 +109,16 @@ def __init__(self, raise e def create_workers(self, worker_cls, worker_kwargs): + llm_args = worker_kwargs.get("llm_args") + placement_config = getattr(llm_args, 'ray_placement_config', + None) if llm_args else None + # When set to be a fraction, it allows Ray to schedule # multiple actors on a single GPU for colocate use cases. 
num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0")) + if placement_config and placement_config.per_worker_gpu_share is not None: + num_gpus = placement_config.per_worker_gpu_share + logger.debug(f"{num_gpus=} for each worker.") runtime_env = ray.runtime_env.RuntimeEnv() @@ -110,28 +129,40 @@ def create_workers(self, worker_cls, worker_kwargs): "MASTER_PORT": str(self.master_port) }) - self.placement_group, self.bundle_indices = self._get_placement_group( - tp_size=self.tp_size) + placement_groups, self.bundle_indices = self._get_placement_group( + tp_size=self.tp_size, worker_kwargs=worker_kwargs) - self.workers = [ - RayWorkerWrapper.options( + if isinstance(placement_groups, list): + self.placement_group = None + else: + self.placement_group = placement_groups + + self.workers = [] + for rank in range(self.world_size): + pg = placement_groups[rank] if isinstance( + placement_groups, list) else placement_groups + worker = RayWorkerWrapper.options( num_gpus=num_gpus, - runtime_env=runtime_env, # per-actor env + runtime_env=runtime_env, scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=self.placement_group, + placement_group=pg, placement_group_bundle_index=self.bundle_indices[rank], )).remote(worker_cls, worker_kwargs, self.world_size, rank) - for rank in range(self.world_size) - ] + self.workers.append(worker) + def init_workers_sync(self): + self.create_workers(RayGPUWorker, self.worker_kwargs) try: - ray.get([worker.__ray_ready__.remote() for worker in self.workers]) + ray.get(self._get_worker_ready_futures()) except ray.exceptions.ActorDiedError as e: - if "The actor died because of an error raised in its creation task" in str( - e): - raise RuntimeError( - "RayGPUWorker died during initialization") from e - raise + raise RuntimeError("RayGPUWorker died during initialization") from e + + async def init_workers_async(self): + self.create_workers(RayGPUWorker, self.worker_kwargs) + try: + await asyncio.gather(*self._get_worker_ready_futures()) + except ray.exceptions.ActorDiedError as e: + raise RuntimeError("RayGPUWorker died during initialization") from e @unwrap_ray_errors() def call_all_ray_workers(self, func: str, leader_only: bool, @@ -171,6 +202,20 @@ def collective_rpc(self, **kwargs)) return refs if non_block else ray.get(refs) + @unwrap_ray_errors() + async def collective_rpc_async( + self, + method: str, + args: tuple = (), + kwargs: Optional[dict] = None, + unique_reply_rank: Optional[int] = None) -> list[Any]: + refs = self.collective_rpc(method, + args, + kwargs, + non_block=True, + unique_reply_rank=unique_reply_rank) + return await asyncio.gather(*refs) + def submit(self, request: "GenerationRequest") -> "GenerationResult": """ Low-level API to the executor. 
Return a "future" GenerationResult @@ -198,6 +243,26 @@ def start(self): def setup_engine_remote(self): return self.collective_rpc("setup_engine", non_block=False) + async def setup_engine_remote_async(self): + """Async version of setup_engine_remote for use after async worker initialization.""" + if not self.workers or len(self.workers) == 0: + raise RuntimeError( + "Workers must be initialized before calling setup_engine_remote_async" + ) + + # Setup engine on all workers + result = await self.collective_rpc_async("setup_engine") + logger.info("setup_engine_remote_async finished") + + # Now that engine is set up, start the mainloop for fetching responses + if hasattr(self, '_mainloop_started') and not self._mainloop_started: + logger.info("Starting mainloop after engine setup") + self.setup_mainloop(tasks=[self._fetch_responses_loop_async], + thread_name="ray_executor_main_loop") + self._mainloop_started = True + + return result + def report_device_ids(self) -> list[str]: gpu_ids = self.call_all_ray_workers("report_device_id", leader_only=False, @@ -265,15 +330,52 @@ def shutdown(self): logger.debug("Shutting down Ray cluster") ray.shutdown() - def _get_placement_group(self, - tp_size: int) -> Tuple[PlacementGroup, List[int]]: + def _get_worker_ready_futures(self): + return [worker.__ray_ready__.remote() for worker in self.workers] + + def _get_placement_group( + self, + tp_size: int, + worker_kwargs: Dict = None) -> Tuple[Any, List[int]]: """ Either use the existing placement group from driver script (e.g., in the case of RL FW integration), or create a default PACK placement group where each bundle has tp_size GPUs. - When tp_size ≤ GPUs per node, keep one TP group per node. - When tp_size > GPUs per node, allow a TP group span nodes. - rank 0 must be put on the driver node + + Returns: + Tuple of (placement_group(s), bundle_indices) + - placement_group(s) can be a single PlacementGroup or a List[PlacementGroup] + - bundle_indices is always a List[int] """ + llm_args = worker_kwargs.get("llm_args") if worker_kwargs else None + + placement_config = getattr(llm_args, 'ray_placement_config', + None) if llm_args else None + if placement_config and placement_config.placement_groups is not None: + total_workers = sum( + len(indices) + for indices in placement_config.placement_bundle_indices) + if total_workers != self.world_size: + raise ValueError( + f"Total bundle indices ({total_workers}) must equal world_size ({self.world_size})" + ) + + logger.info( + f"Creating {self.world_size} workers with external placement groups" + ) + + flat_pgs = [] + flat_indices = [] + for pg, indices in zip(placement_config.placement_groups, + placement_config.placement_bundle_indices): + for idx in indices: + flat_pgs.append(pg) + flat_indices.append(idx) + + return flat_pgs, flat_indices + bundle_indices = os.getenv("TRTLLM_RAY_BUNDLE_INDICES", None) if bundle_indices: diff --git a/tensorrt_llm/executor/ray_gpu_worker.py b/tensorrt_llm/executor/ray_gpu_worker.py index 48f036abeb0..fca5386cb54 100644 --- a/tensorrt_llm/executor/ray_gpu_worker.py +++ b/tensorrt_llm/executor/ray_gpu_worker.py @@ -1,3 +1,4 @@ +import gc import importlib import os from pathlib import Path @@ -43,7 +44,6 @@ class RayWorkerWrapper: def __init__(self, worker_cls, worker_kwargs, world_size, rank): self.master_address = os.environ["MASTER_ADDR"] self.master_port = os.environ["MASTER_PORT"] - # Ray can't pickle TensorRT logger global logger from tensorrt_llm.logger import logger @@ -218,6 +218,8 @@ def sleep(self, sleep_tags: 
List[str]): torch.cuda.synchronize() release_with_tag(*tags) torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() except Exception as e: logger.error(f"Encountered an error in sleep: {e}") raise e diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py index cb868d8d068..8563b9090c8 100644 --- a/tensorrt_llm/llmapi/__init__.py +++ b/tensorrt_llm/llmapi/__init__.py @@ -1,3 +1,4 @@ +from .._torch.async_llm import AsyncLLM from ..disaggregated_params import DisaggregatedParams from ..executor import CompletionOutput, LoRARequest, RequestError from ..sampling_params import GuidedDecodingParams, SamplingParams @@ -23,6 +24,7 @@ __all__ = [ 'LLM', + 'AsyncLLM', 'MultimodalEncoder', 'CompletionOutput', 'RequestOutput', diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 41c9bdeeaee..33774f0ed8f 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -193,7 +193,7 @@ def __init__(self, self.mpi_session = self.args.mpi_session if self.args.parallel_config.is_multi_gpu: - if get_device_count( + if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count( ) < self.args.parallel_config.world_size_per_node: raise RuntimeError( f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required." @@ -229,7 +229,6 @@ def __init__(self, self.runtime_context: Optional[_ModelRuntimeContext] = None self.llm_build_stats = LlmBuildStats() - self._build_model() except Exception: diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 9f154c53f63..6720a2df321 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -19,6 +19,11 @@ from strenum import StrEnum from transformers import PreTrainedTokenizerBase +try: + from ray.util.placement_group import PlacementGroup +except ImportError: + PlacementGroup = None + from tensorrt_llm.lora_helper import (LoraConfig, get_default_trtllm_modules_to_hf_modules) @@ -1086,6 +1091,65 @@ def supports_backend(self, backend: str) -> bool: return backend == "pytorch" +class RayPlacementConfig(StrictBaseModel): + """ + Configuration for Ray GPU workers placement. + This config is only used with AsyncLLM for RL scenarios. + """ + defer_workers_init: bool = Field( + default=False, + description="Defer Ray worker initialization until async setup.") + + placement_groups: Optional[List[Any]] = Field( + default=None, + description="List of Ray placement groups, one per node. " + "Each element must be a ray.util.placement_group.PlacementGroup instance." + ) + + placement_bundle_indices: Optional[List[List[int]]] = Field( + default=None, + description="List of bundle indices for each placement group. " + "Outer list corresponds to placement_groups, inner list contains bundle indices for that group." + ) + + per_worker_gpu_share: Optional[float] = Field( + default=None, + description="GPU fraction per worker for colocation scenarios. " + "Example: 0.1 means 10 actors can share one GPU. Defaults to 1.0 (one actor per GPU)." 
+ ) + + @model_validator(mode='after') + def validate_ray_placement(self) -> 'RayPlacementConfig': + has_pgs = self.placement_groups is not None + has_indices = self.placement_bundle_indices is not None + + if has_pgs != has_indices: + raise ValueError( + "placement_groups and placement_bundle_indices must be provided together" + ) + + if has_pgs: + if len(self.placement_groups) != len(self.placement_bundle_indices): + raise ValueError( + f"placement_groups length ({len(self.placement_groups)}) must equal " + f"placement_bundle_indices length ({len(self.placement_bundle_indices)})" + ) + if PlacementGroup is not None: + for i, pg in enumerate(self.placement_groups): + if not isinstance(pg, PlacementGroup): + raise TypeError( + f"placement_groups[{i}] must be a Ray PlacementGroup, " + f"got {type(pg).__name__}") + + if self.per_worker_gpu_share is not None: + if not (0 < self.per_worker_gpu_share <= 1.0): + raise ValueError( + f"per_worker_gpu_share must be between 0 and 1.0, " + f"got {self.per_worker_gpu_share}") + + return self + + class PybindMirror(ABC): ''' A class containing the utilities for mirroring Python classes to pybinding classes. @@ -2032,6 +2096,8 @@ def validate_dtype(cls, v, info): @field_validator("gpus_per_node", mode='before') @classmethod def validate_gpus_per_node(cls, v, info): + if os.getenv("RAY_LOCAL_WORLD_SIZE") is not None: + return info.data.get("tensor_parallel_size") if v is None: logger.warning( f"Using default gpus_per_node: {torch.cuda.device_count()}") @@ -2741,6 +2807,13 @@ class TorchLlmArgs(BaseLlmArgs): "Allows users to extend the functions of the RayGPUWorker class.", status="prototype") + ray_placement_config: Optional[RayPlacementConfig] = Field( + default=None, + description= + "Placement config for RayGPUWorker. 
Only used with AsyncLLM and orchestrator_type='ray'.", + exclude=True, + status="prototype") + enable_sleep: bool = Field( default=False, description= @@ -3050,6 +3123,14 @@ def validate_ray_worker_extension_cls(self) -> 'TorchLlmArgs': ) return self + @model_validator(mode='after') + def validate_ray_placement_config(self) -> 'TorchLlmArgs': + if self.ray_placement_config is not None and self.orchestrator_type != "ray": + raise ValueError( + "ray_placement_config is only supported with orchestrator_type='ray'" + ) + return self + def get_executor_config( self, _hf_model_dir: Optional[Path] = None, diff --git a/tensorrt_llm/llmapi/rlhf_utils.py b/tensorrt_llm/llmapi/rlhf_utils.py index 4934d40e979..ce6eaa5b4ff 100644 --- a/tensorrt_llm/llmapi/rlhf_utils.py +++ b/tensorrt_llm/llmapi/rlhf_utils.py @@ -1,3 +1,5 @@ +import base64 +import pickle # nosec B403 from typing import Optional import torch @@ -56,12 +58,20 @@ def update_weights(self, ipc_handles: Optional[dict] = None): raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles") weights = {} - all_handles = ipc_handles[device_uuid] + + serialized_handles = ipc_handles[device_uuid] + if isinstance(serialized_handles, str): + # Data is base64-encoded pickled bytes - deserialize it + logger.info("Deserializing base64-encoded weight handles") + all_handles = pickle.loads(base64.b64decode(serialized_handles)) # nosec B301 + else: + # Data is already in the correct format (backward compatibility) + all_handles = serialized_handles for param_name, tensor_handle in all_handles: func, args = tensor_handle list_args = list(args) - list_args[6] = self.device_id # Set target device + list_args[6] = self.device_id tensor = func(*list_args) weights[param_name] = tensor @@ -88,7 +98,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None): logger.error("Encountered an error in update_weights") raise e - def check_weights_updated(self): + def check_weights_updated(self) -> bool: """Check if the weights are updated to 0.""" weights_updated = True for name, p in self.engine.model_engine.model.named_parameters(): diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index 9dc837810e0..8ddda27cd7f 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -968,6 +968,16 @@ class ResponsesStreamResponse(OpenAIBaseModel): "response.incomplete"] +class MemoryUpdateRequest(OpenAIBaseModel): + tags: List[str] = Field(default=["model", "kv_cache"]) + + +class UpdateWeightsRequest(OpenAIBaseModel): + weights: Optional[Dict[str, str]] = Field( + default=None, + description="Weight handles dict, or None to finalize update") + + def encode_opaque_state(opaque_state: Optional[bytes]) -> Optional[str]: if opaque_state is None: return None diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index e64c5d20df6..3811c8a12e4 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -21,6 +21,7 @@ from transformers import AutoProcessor from tensorrt_llm._tensorrt_engine import LLM +from tensorrt_llm._torch.async_llm import AsyncLLM # yapf: disable from tensorrt_llm.executor import CppExecutorError from tensorrt_llm.executor.postproc_worker import PostprocParams @@ -46,9 +47,11 @@ ChatMessage, CompletionRequest, CompletionResponse, CompletionResponseChoice, - ErrorResponse, ModelCard, + ErrorResponse, + MemoryUpdateRequest, ModelCard, ModelList, PromptTokensDetails, - ResponsesRequest, UsageInfo, + 
ResponsesRequest, + UpdateWeightsRequest, UsageInfo, to_llm_disaggregated_params) from tensorrt_llm.serve.postprocess_handlers import ( ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs, @@ -262,6 +265,16 @@ def register_routes(self): self.app.add_api_route("/v1/responses", self.openai_responses, methods=["POST"]) + # RL-only endpoints + self.app.add_api_route("/release_memory", + self.release_memory, + methods=["POST"]) + self.app.add_api_route("/resume_memory", + self.resume_memory, + methods=["POST"]) + self.app.add_api_route("/update_weights", + self.update_weights, + methods=["POST"]) if self.llm.args.return_perf_metrics: # register /prometheus/metrics self.mount_metrics() @@ -298,6 +311,16 @@ def register_mm_encoder_routes(self): self.app.add_api_route("/v1/chat/completions", self.openai_mm_encoder, methods=["POST"]) + # RL-only endpoints + self.app.add_api_route("/release_memory", + self.release_memory, + methods=["POST"]) + self.app.add_api_route("/resume_memory", + self.resume_memory, + methods=["POST"]) + self.app.add_api_route("/update_weights", + self.update_weights, + methods=["POST"]) async def health(self) -> Response: if self._check_health(): @@ -990,6 +1013,20 @@ async def create_stream_response(generator, request: ResponsesRequest, sampling_ return JSONResponse(content={"detail": "None"}) + async def release_memory(self, request: MemoryUpdateRequest) -> JSONResponse: + assert isinstance(self.llm, AsyncLLM), "/release_memory endpoint is only supported with AsyncLLM()" + await self.llm.collective_rpc('sleep', args=(request.tags,)) + return JSONResponse(content={"status": "success"}) + + async def resume_memory(self, request: MemoryUpdateRequest) -> JSONResponse: + assert isinstance(self.llm, AsyncLLM), "/resume_memory endpoint is only supported with AsyncLLM()" + await self.llm.collective_rpc('wakeup', args=(request.tags,)) + return JSONResponse(content={"status": "success"}) + + async def update_weights(self, request: UpdateWeightsRequest) -> JSONResponse: + assert isinstance(self.llm, AsyncLLM), "/update_weights endpoint is only supported with AsyncLLM()" + await self.llm.collective_rpc('update_weights', args=(request.weights,)) + return JSONResponse(content={"status": "success"}) async def __call__(self, host, port, sockets: list[socket.socket] | None = None): # Store the binding address for server registration diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index 0eca7d48474..cbe6807381a 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -276,8 +276,25 @@ l0_dgx_h100: tests: - unittest/_torch/ray_orchestrator/multi_gpu -m "gpu2" - unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu2" + - unittest/llmapi/test_async_llm.py -m "gpu2" - accuracy/test_llm_api_pytorch_ray.py::TestLlama3_1_8BInstruct::test_pp2_ray - examples/test_ray.py::test_llm_inference_distributed_ray[tp2] - examples/test_ray.py::test_llm_inference_distributed_ray[pp2] - examples/test_ray.py::test_llm_inference_distributed_ray[tep2] - examples/test_ray.py::test_ray_disaggregated_serving[tp1] +- condition: + ranges: + system_gpu_count: + gte: 4 + lte: 4 + wildcards: + gpu: + - '*h100*' + linux_distribution_name: ubuntu* + terms: + stage: pre_merge + backend: pytorch + orchestrator: ray + tests: + - unittest/_torch/ray_orchestrator/multi_gpu -m "gpu4" + - unittest/llmapi/test_async_llm.py -m "gpu4" diff --git 
a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 42859c06ec5..ecd52f05bcb 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -143,6 +143,7 @@ l0_h100: - unittest/_torch/executor - unittest/_torch/ray_orchestrator/single_gpu - unittest/llmapi/test_llm_pytorch.py + - unittest/llmapi/test_async_llm.py -m "not (gpu2 or gpu4)" - examples/test_ray.py::test_llm_inference_async_ray - condition: ranges: diff --git a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py index bea4f94d713..578be1f6dd8 100644 --- a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py +++ b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py @@ -9,27 +9,23 @@ from tensorrt_llm import LLM from tensorrt_llm._torch.utils import get_device_uuid from tensorrt_llm.llmapi import KvCacheConfig +from tensorrt_llm.llmapi.llm_args import RayPlacementConfig -class DummyWorkerExtension: - - def additional_method(self): - return "SUCCESS" - - +@pytest.mark.gpu2 def test_worker_extension(): llm = LLM(model=llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0", - ray_worker_extension_cls="test_executor.DummyWorkerExtension", - orchestrator_type="ray") - result = llm._collective_rpc("additional_method") - assert result[0] == "SUCCESS" + ray_worker_extension_cls= + "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", + orchestrator_type="ray", + tensor_parallel_size=2) + result = llm._collective_rpc("check_weights_updated") + assert isinstance(result[0], bool) @pytest.mark.gpu4 -def test_bundle_indices(monkeypatch): - """Placement via bundle indices""" - +def test_placement_env_vars(monkeypatch): monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1") pg = None @@ -77,6 +73,52 @@ def test_bundle_indices(monkeypatch): ray.shutdown() +@pytest.mark.gpu2 +@pytest.mark.threadleak(enabled=False) +@pytest.mark.parametrize("n_gpus,bundle_indices", [ + (2, [1]), +], + ids=["gpu2_tp1"]) +def test_placement_api(monkeypatch, n_gpus, bundle_indices): + monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1") + + tp_size = n_gpus // 2 + pg = None + try: + ray.init() + pg = placement_group([{"GPU": 1, "CPU": 1}] * n_gpus) + ray.get(pg.ready()) + print(f"Placement group ready with bundles {pg.bundle_specs}") + + llm = LLM( + model=os.path.join(llm_models_root(), "llama-models-v2", + "TinyLlama-1.1B-Chat-v1.0"), + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1), + tensor_parallel_size=tp_size, + orchestrator_type="ray", + ray_placement_config=RayPlacementConfig( + placement_groups=[pg], + placement_bundle_indices=[bundle_indices], + per_worker_gpu_share=0.8, + ), + ) + + inference_actor_uuids = llm._collective_rpc("report_device_id") + expected_uuids = [get_device_uuid(idx) for idx in bundle_indices] + + print( + f"{inference_actor_uuids=}, all_uuids={[get_device_uuid(i) for i in range(n_gpus)]}" + ) + + assert sorted(inference_actor_uuids) == sorted(expected_uuids), \ + f"Workers not placed on expected GPUs. 
Expected: {expected_uuids}, Got: {inference_actor_uuids}" + + finally: + if pg is not None: + remove_placement_group(pg) + ray.shutdown() + + @pytest.mark.gpu2 def test_cuda_visible_device(monkeypatch): """Placement via cuda_visible_device""" diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 6d02fed3975..02b4500a1e5 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -207,6 +207,10 @@ methods: annotation: Optional[str] default: null status: prototype + ray_placement_config: + annotation: Optional[tensorrt_llm.llmapi.llm_args.RayPlacementConfig] + default: null + status: prototype enable_sleep: annotation: bool default: False diff --git a/tests/unittest/llmapi/test_async_llm.py b/tests/unittest/llmapi/test_async_llm.py new file mode 100644 index 00000000000..e0e7dd6d0ff --- /dev/null +++ b/tests/unittest/llmapi/test_async_llm.py @@ -0,0 +1,137 @@ +import os + +import pytest +import ray +from ray.util.placement_group import placement_group, remove_placement_group +from utils.llm_data import llm_models_root +from utils.util import get_current_process_gpu_memory + +from tensorrt_llm import AsyncLLM +from tensorrt_llm._torch.utils import get_device_uuid +from tensorrt_llm._torch.virtual_memory import ExecutorMemoryType +from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams + + +@pytest.mark.ray +@pytest.mark.asyncio +async def test_async_llm_awaitable(): + llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0") + kv_cache_config = KvCacheConfig(enable_block_reuse=False) + + prompt = "The future of AI is" + sampling_params = SamplingParams(temperature=0, max_tokens=12) + + llm = await AsyncLLM( + model=llama_model_path, + enable_sleep=True, + cuda_graph_config=None, + kv_cache_config=kv_cache_config, + ) + + output = await llm.generate_async(prompt, sampling_params) + assert output.outputs[0].text + print("Output text:", output.outputs[0].text) + + del llm + + +@pytest.mark.ray +@pytest.mark.gpu2 +@pytest.mark.asyncio +@pytest.mark.parametrize("num_cycles", [3], ids=lambda x: f"{x}_cycle") +async def test_async_llm_release_resume(process_gpu_memory_info_available, num_cycles): + llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0") + kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=4096) + + prompt = "The future of AI is" + sampling_params = SamplingParams(temperature=0, max_tokens=12) + tags = [tag.value for tag in ExecutorMemoryType] + + async with AsyncLLM( + model=llama_model_path, + enable_sleep=True, + cuda_graph_config=None, + kv_cache_config=kv_cache_config, + tensor_parallel_size=2, + ) as llm: + # Generate baseline + output_before = await llm.generate_async(prompt, sampling_params) + baseline_text = output_before.outputs[0].text + + for cycle in range(num_cycles): + memory_usage_active = get_current_process_gpu_memory(True) / 1024**3 + print(f"[Cycle {cycle + 1}] Memory usage before release: {memory_usage_active:.2f} GB") + + await llm.release(tags) + memory_usage_released = get_current_process_gpu_memory(True) / 1024**3 + + if process_gpu_memory_info_available: + print( + f"[Cycle {cycle + 1}] Memory usage after release: {memory_usage_released:.2f} GB" + ) + assert memory_usage_released < memory_usage_active, ( + f"Released memory ({memory_usage_released:.2f} GB) should be < " + f"active memory ({memory_usage_active:.2f} GB)" + ) + + await llm.resume(tags) + 
memory_usage_resumed = get_current_process_gpu_memory(True) / 1024**3 + print(f"[Cycle {cycle + 1}] Memory usage after resume: {memory_usage_resumed:.2f} GB") + if process_gpu_memory_info_available: + assert memory_usage_resumed > memory_usage_released, ( + f"Resumed memory ({memory_usage_resumed:.2f} GB) should be > " + f"released memory ({memory_usage_released:.2f} GB)" + ) + + output_after = await llm.generate_async(prompt, sampling_params) + text_after = output_after.outputs[0].text + + print(f"[Cycle {num_cycles}] Generated text after release/resume: {text_after}") + assert baseline_text == text_after, ( + f"Generated text mismatch after {num_cycles} cycle(s): " + f"'{baseline_text}' != '{text_after}'" + ) + + +@pytest.mark.ray +@pytest.mark.gpu4 +@pytest.mark.asyncio +@pytest.mark.threadleak(enabled=False) +async def test_async_llm_placement_api(monkeypatch): + monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1") + + n_gpus = 4 + bundle_indices = [2, 3] + tp_size = len(bundle_indices) + + pg = None + try: + ray.init() + pg = placement_group([{"GPU": 1, "CPU": 1}] * n_gpus) + ray.get(pg.ready()) + print(f"Placement group ready with bundles {pg.bundle_specs}") + + llm = await AsyncLLM( + model=os.path.join( + str(llm_models_root()), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0" + ), + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1), + tensor_parallel_size=tp_size, + placement_groups=[pg], + placement_bundle_indices=[bundle_indices], + per_worker_gpu_share=0.8, + ) + + inference_actor_uuids = await llm.collective_rpc("report_device_id") + expected_uuids = [get_device_uuid(idx) for idx in bundle_indices] + + print(f"{inference_actor_uuids=}, all_uuids={[get_device_uuid(i) for i in range(n_gpus)]}") + + assert sorted(inference_actor_uuids) == sorted(expected_uuids), ( + f"Workers not placed on expected GPUs. Expected: {expected_uuids}, Got: {inference_actor_uuids}" + ) + + finally: + if pg is not None: + remove_placement_group(pg) + ray.shutdown()
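
For context, here is a minimal sketch of how the new `AsyncLLM` API is intended to be driven from an RL training loop. Only the `AsyncLLM` methods shown in this diff are assumed; the model path, step count, and the trainer step itself are hypothetical placeholders, and the release/resume ordering is one possible arrangement, not the only supported one.

```python
from tensorrt_llm import AsyncLLM, SamplingParams


async def rollout_and_train(model_path: str, num_steps: int = 2):
    # Worker init is deferred; the engine is brought up when the object is
    # awaited or entered via `async with` (setup_async under the hood).
    async with AsyncLLM(model=model_path, enable_sleep=True) as llm:
        params = SamplingParams(temperature=0.8, max_tokens=64)
        for _ in range(num_steps):
            # 1) Rollout phase: generate samples for the RL algorithm.
            out = await llm.generate_async("The future of AI is", params)
            print(out.outputs[0].text)

            # 2) Free GPU memory so a colocated trainer can use the device(s).
            await llm.release(["model", "kv_cache"])

            # ... trainer.step() would run here (hypothetical) ...

            # 3) Bring the released memory back, then push the new weights
            #    (device-UUID -> serialized IPC handles, see sketch below).
            await llm.resume(["model", "kv_cache"])
            # await llm.update_weights(ipc_handles)


# Run with: asyncio.run(rollout_and_train("<path-to-model>"))
```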
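The new base64/pickle path in `WorkerExtension.update_weights` expects `ipc_handles` keyed by device UUID, where each value is either the raw list of `(param_name, (rebuild_fn, args))` pairs or a base64-encoded pickle of that list. A rough sketch of what the producer (trainer) side might look like, assuming the weights are CUDA tensors colocated with the rollout workers and that `torch.multiprocessing.reductions.reduce_tensor` is used to create the handles; the helper name is hypothetical.

```python
import base64
import pickle

import torch
from torch.multiprocessing.reductions import reduce_tensor


def pack_weight_handles(model: torch.nn.Module, device_uuid: str) -> dict[str, str]:
    """Serialize CUDA IPC handles for every parameter into the string format
    consumed by WorkerExtension.update_weights."""
    handles = []
    for name, param in model.named_parameters():
        # reduce_tensor returns (rebuild_fn, args); the worker rewrites the
        # device index inside args before rebuilding the tensor locally.
        handles.append((name, reduce_tensor(param.detach())))
    payload = base64.b64encode(pickle.dumps(handles)).decode("utf-8")
    return {device_uuid: payload}


# Usage (device_uuid should match what workers report via report_device_id):
# ipc_handles = pack_weight_handles(trainer_model, device_uuid)
# await llm.update_weights(ipc_handles)
```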
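The three RL-only routes added in `openai_server.py` can be exercised with plain JSON POSTs matching `MemoryUpdateRequest` and `UpdateWeightsRequest`. A minimal sketch, assuming the server is backed by an `AsyncLLM` instance and listens on `localhost:8000` (both assumptions, not part of this change):

```python
import requests

BASE_URL = "http://localhost:8000"  # assumed serving address

# Release model weights and KV cache before a training phase.
resp = requests.post(f"{BASE_URL}/release_memory", json={"tags": ["model", "kv_cache"]})
assert resp.json() == {"status": "success"}

# Bring the memory back before the next rollout phase.
resp = requests.post(f"{BASE_URL}/resume_memory", json={"tags": ["model", "kv_cache"]})
assert resp.json() == {"status": "success"}

# Push serialized weight handles (same payload format as AsyncLLM.update_weights).
# resp = requests.post(f"{BASE_URL}/update_weights", json={"weights": ipc_handles})
```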