Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 61 additions & 9 deletions tensorrt_llm/executor/ray_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,19 @@ def __init__(self,
if self.use_rpc:
self.init_rpc_executor()
self.worker_kwargs['rpc_addr'] = self.rpc_addr
if not has_event_loop():
self.init_workers_sync()
self.setup_engine_remote()
self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
thread_name="ray_executor_main_loop")
logger.info(f"Connecting to RPC server at {self.rpc_addr}")

# Always use async initialization in RPC mode to avoid blocking ray.get()
logger.info(
"RPC mode detected - using async initialization (deferred worker & mainloop setup)"
)
self.workers = [
] # Placeholder, will be initialized in setup_async
self._needs_async_setup = True
self._mainloop_started = False # Track if mainloop has been started
# DO NOT start mainloop until after setup_engine_remote_async is called
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Superjomn wouldn't this break non RL-integ flow?

logger.info(
f"Will connect to RPC server at {self.rpc_addr} after async init"
)
else:
self.response_queue = RayAsyncQueue.options(runtime_env={
"env_vars": {
Expand Down Expand Up @@ -123,9 +130,13 @@ def create_workers(self, worker_cls, worker_kwargs):

# When set to be a fraction, it allows Ray to schedule
# multiple actors on a single GPU for colocate use cases.
num_gpus = (llm_args.per_worker_gpu_share if llm_args
and llm_args.per_worker_gpu_share is not None else float(
os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0")))
num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))
if llm_args:
if getattr(llm_args, 'placement_share', None) is not None:
num_gpus = llm_args.placement_share
elif llm_args.per_worker_gpu_share is not None:
num_gpus = llm_args.per_worker_gpu_share

logger.debug(f"{num_gpus=} for each worker.")

runtime_env = ray.runtime_env.RuntimeEnv()
Expand Down Expand Up @@ -266,6 +277,26 @@ def start(self):
def setup_engine_remote(self):
    """Set up the engine on all workers via a blocking collective RPC.

    Returns the collective RPC's result directly to the caller.
    """
    outcome = self.collective_rpc("setup_engine", non_block=False)
    return outcome

async def setup_engine_remote_async(self):
    """Async version of setup_engine_remote for use after async worker initialization.

    Must be called only after the worker actors have been created (the
    deferred RPC-mode initialization path). Once the engine is set up on
    every worker, starts the response-fetching mainloop exactly once.

    Returns:
        The result of the collective "setup_engine" RPC across all workers.

    Raises:
        RuntimeError: if called before ``self.workers`` is populated.
    """
    # An empty or unset workers list means async initialization never ran;
    # `not self.workers` already covers both None and the empty list.
    if not self.workers:
        raise RuntimeError(
            "Workers must be initialized before calling setup_engine_remote_async"
        )

    # Set up the engine on all workers.
    result = await self.collective_rpc_async("setup_engine")
    logger.info("setup_engine_remote_async finished")

    # Start the mainloop only now that the engine exists, and only once.
    # The default of True makes this a no-op when the attribute was never
    # initialized (e.g. on a construction path that did not defer setup),
    # matching the original `hasattr(...) and not ...` check.
    if not getattr(self, '_mainloop_started', True):
        logger.info("Starting mainloop after engine setup")
        self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
                            thread_name="ray_executor_main_loop")
        self._mainloop_started = True

    return result

def report_device_ids(self) -> list[str]:
gpu_ids = self.call_all_ray_workers("report_device_id",
leader_only=False,
Expand Down Expand Up @@ -371,6 +402,27 @@ def _get_placement_group(
"""
llm_args = worker_kwargs.get("llm_args") if worker_kwargs else None

if llm_args and getattr(llm_args, 'placement_where', None) is not None:
total_workers = sum(
len(indices) for _, indices in llm_args.placement_where)
if total_workers != self.world_size:
raise ValueError(
f"Total bundle indices ({total_workers}) must equal world_size ({self.world_size})"
)

logger.info(
f"Creating {self.world_size} workers with external placement_where"
)

flat_pgs = []
flat_indices = []
for pg, indices in llm_args.placement_where:
for idx in indices:
flat_pgs.append(pg)
flat_indices.append(idx)

return flat_pgs, flat_indices

if llm_args and hasattr(
llm_args,
'placement_groups') and llm_args.placement_groups is not None:
Expand Down
16 changes: 13 additions & 3 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from dataclasses import dataclass
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
Set, Tuple, Type, TypeAlias, TypeVar, Union, get_args,
get_origin)
from typing import (Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple,
Type, TypeAlias, TypeVar, Union, get_args, get_origin)

import torch
import yaml
Expand Down Expand Up @@ -2730,6 +2729,17 @@ class TorchLlmArgs(BaseLlmArgs):
"Example: 0.1 means 10 actors can share one GPU. Defaults to 1.0 (one actor per GPU).",
status="prototype")

placement_where: Optional[List[Tuple[Any, List[int]]]] = Field(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know these two knobs are duplicates, just put here to avoid modification on the verl side. We may clean-up the knobs finally.

default=None,
description="List of (PlacementGroup, List[int]) tuples for each node.",
status="prototype")

placement_share: Optional[float] = Field(
default=None,
description="GPU fraction per worker for colocation scenarios. "
"Example: 0.1 means 10 actors can share one GPU. Defaults to 1.0 (one actor per GPU).",
status="prototype")

enable_sleep: bool = Field(
default=False,
description=
Expand Down
14 changes: 11 additions & 3 deletions tensorrt_llm/llmapi/rlhf_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import pickle
from typing import Optional

import torch
Expand All @@ -6,8 +8,6 @@
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.logger import logger
import pickle
import base64


class WorkerExtension:
Expand Down Expand Up @@ -58,7 +58,15 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")

weights = {}
all_handles = ipc_handles[device_uuid]
# Deserialize the base64-encoded pickled data
serialized_handles = ipc_handles[device_uuid]
if isinstance(serialized_handles, str):
# Data is base64-encoded pickled bytes - deserialize it
logger.info("Deserializing base64-encoded weight handles")
all_handles = pickle.loads(base64.b64decode(serialized_handles))
else:
# Data is already in the correct format (backward compatibility)
all_handles = serialized_handles

for param_name, tensor_handle in all_handles:
func, args = tensor_handle
Expand Down
Loading