Commit 6e9b967: update placement api
Parent: 54e3b37

File tree: 5 files changed (+181, -62 lines)

tensorrt_llm/executor/executor.py

Lines changed: 2 additions & 11 deletions
@@ -37,7 +37,6 @@
 from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
 from .result import GenerationResult, IterationResult
 from .utils import IntraProcessQueue, ProcessPoolExecutorSession, RequestError
-from ray.util.placement_group import PlacementGroup

 if TYPE_CHECKING:
     from .proxy import GenerationExecutorProxy
@@ -368,8 +367,6 @@ def _create_ray_executor(
         postproc_worker_config: PostprocWorkerConfig,
         is_llm_executor: bool,
         tp_size: int,
-        placement_share: float = 1.0,
-        placement_where: list[tuple[PlacementGroup, list[int]]] = None,
     ):
         logger.warning(f"Orchestrator is creating Ray executor")
         from .ray_executor import RayExecutor
@@ -378,9 +375,7 @@ def _create_ray_executor(
             model_world_size=model_world_size,
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
-            tp_size=tp_size,
-            placement_share=placement_share,
-            placement_where=placement_where)
+            tp_size=tp_size)

     @staticmethod
     def _create_rpc_executor(
@@ -444,8 +439,6 @@ def create(
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
         llm_args: Optional[BaseLlmArgs] = None,
-        placement_share: float = 1.0,
-        placement_where: list[tuple[PlacementGroup, list[int]]] = None,
         **args,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         if world_size == 0:
@@ -485,9 +478,7 @@ def create(
                 model_world_size,
                 postproc_worker_config,
                 is_llm_executor=is_llm_executor,
-                tp_size=args.get("tp_size", 1),
-                placement_share=placement_share,
-                placement_where=placement_where)
+                tp_size=args.get("tp_size", 1))
         elif orchestrator_type is not None and orchestrator_type != "rpc":
             raise ValueError(
                 f"Unsupported orchestrator_type: {orchestrator_type}")

tensorrt_llm/executor/ray_executor.py

Lines changed: 61 additions & 25 deletions
@@ -8,8 +8,7 @@
     e.msg = """Cannot import Ray. Please install 'ray' package to use ray orchestrator"""
     raise

-from ray.util.placement_group import (PlacementGroup,
-                                      PlacementGroupSchedulingStrategy,
+from ray.util.placement_group import (PlacementGroupSchedulingStrategy,
                                       get_current_placement_group,
                                       placement_group)

@@ -38,18 +37,13 @@ def __init__(self,
                  model_world_size: int,
                  postproc_worker_config: PostprocWorkerConfig,
                  is_llm_executor: bool,
-                 tp_size=1,
-                 placement_share: float = 1.0,
-                 placement_where: list[tuple[PlacementGroup, list[int]]] = None):
+                 tp_size=1):
         os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1'
         os.environ["RAY_DEDUP_LOGS"] = "0"  # for debug

         super().__init__(model_world_size, postproc_worker_config,
                          is_llm_executor)

-        self.placement_share = placement_share
-        self.placement_where = placement_where
-
         self.has_start_local_cluser = False
         runtime_env = {
             "env_vars": {
@@ -125,9 +119,13 @@ def __init__(self,
             raise e

     def create_workers(self, worker_cls, worker_kwargs):
+        llm_args = worker_kwargs.get("llm_args")
+
         # When set to be a fraction, it allows Ray to schedule
         # multiple actors on a single GPU for colocate use cases.
-        num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))
+        num_gpus = (llm_args.per_worker_gpu_share if llm_args
+                    and llm_args.per_worker_gpu_share is not None else float(
+                        os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0")))
         logger.debug(f"{num_gpus=} for each worker.")

         runtime_env = ray.runtime_env.RuntimeEnv()
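As an aside on the hunk above: the per-worker GPU fraction now prefers the new llm_args.per_worker_gpu_share field and only falls back to the TRTLLM_RAY_PER_WORKER_GPUS environment variable. A minimal standalone sketch of that precedence (the resolve_num_gpus helper is illustrative, not part of this commit):

import os

def resolve_num_gpus(per_worker_gpu_share=None):
    # Mirrors create_workers(): the llm_args field wins, then the env var, then 1.0.
    if per_worker_gpu_share is not None:
        return per_worker_gpu_share
    return float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))

assert resolve_num_gpus(0.25) == 0.25           # explicit llm_args value
os.environ["TRTLLM_RAY_PER_WORKER_GPUS"] = "0.5"
assert resolve_num_gpus() == 0.5                # env-var fallback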
@@ -138,21 +136,26 @@ def create_workers(self, worker_cls, worker_kwargs):
             "MASTER_PORT": str(self.master_port)
         })

-        rank = 0
-        self.world_size = sum(len(bundle_indices) for _, bundle_indices in self.placement_where)
+        placement_groups, self.bundle_indices = self._get_placement_group(
+            tp_size=self.tp_size, worker_kwargs=worker_kwargs)
+
+        if isinstance(placement_groups, list):
+            self.placement_group = None
+        else:
+            self.placement_group = placement_groups
+
         self.workers = []
-        for pg, bundle_indices in self.placement_where:
-            for bundle_index in bundle_indices:
-                self.workers.append(
-                    RayWorkerWrapper.options(
-                        num_gpus=self.placement_share,
-                        runtime_env=runtime_env,  # per-actor env
-                        scheduling_strategy=PlacementGroupSchedulingStrategy(
-                            placement_group=pg,
-                            placement_group_bundle_index=bundle_index,
-                        )).remote(worker_cls, worker_kwargs, self.world_size, rank)
-                )
-                rank += 1
+        for rank in range(self.world_size):
+            pg = placement_groups[rank] if isinstance(
+                placement_groups, list) else placement_groups
+            worker = RayWorkerWrapper.options(
+                num_gpus=num_gpus,
+                runtime_env=runtime_env,
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=pg,
+                    placement_group_bundle_index=self.bundle_indices[rank],
+                )).remote(worker_cls, worker_kwargs, self.world_size, rank)
+            self.workers.append(worker)

     def init_workers_sync(self):
         self.create_workers(RayGPUWorker, self.worker_kwargs)
@@ -336,15 +339,48 @@ def shutdown(self):
     def _get_worker_ready_futures(self):
         return [worker.__ray_ready__.remote() for worker in self.workers]

-    def _get_placement_group(self,
-                             tp_size: int) -> Tuple[PlacementGroup, List[int]]:
+    def _get_placement_group(
+            self,
+            tp_size: int,
+            worker_kwargs: Dict = None) -> Tuple[Any, List[int]]:
         """
         Either use the existing placement group from driver script (e.g., in the case of RL FW integration),
         or create a default PACK placement group where each bundle has tp_size GPUs.
         - When tp_size ≤ GPUs per node, keep one TP group per node.
         - When tp_size > GPUs per node, allow a TP group span nodes.
         - rank 0 must be put on the driver node
+
+        Returns:
+            Tuple of (placement_group(s), bundle_indices)
+            - placement_group(s) can be a single PlacementGroup or a List[PlacementGroup]
+            - bundle_indices is always a List[int]
         """
+        llm_args = worker_kwargs.get("llm_args") if worker_kwargs else None
+
+        if llm_args and hasattr(
+                llm_args,
+                'placement_groups') and llm_args.placement_groups is not None:
+            total_workers = sum(
+                len(indices) for indices in llm_args.placement_bundle_indices)
+            if total_workers != self.world_size:
+                raise ValueError(
+                    f"Total bundle indices ({total_workers}) must equal world_size ({self.world_size})"
+                )
+
+            logger.info(
+                f"Creating {self.world_size} workers with external placement groups"
+            )
+
+            flat_pgs = []
+            flat_indices = []
+            for pg, indices in zip(llm_args.placement_groups,
+                                   llm_args.placement_bundle_indices):
+                for idx in indices:
+                    flat_pgs.append(pg)
+                    flat_indices.append(idx)
+
+            return flat_pgs, flat_indices
+
         bundle_indices = os.getenv("TRTLLM_RAY_BUNDLE_INDICES", None)

         if bundle_indices:
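A minimal standalone illustration of the flattening performed by the external-placement-group branch above (pg_a and pg_b are placeholders standing in for real ray.util.placement_group handles; the values are made up):

placement_groups = ["pg_a", "pg_b"]          # one entry per node
placement_bundle_indices = [[0, 1], [0, 1]]  # bundle indices used in each group

flat_pgs, flat_indices = [], []
for pg, indices in zip(placement_groups, placement_bundle_indices):
    for idx in indices:
        flat_pgs.append(pg)
        flat_indices.append(idx)

# Worker rank r is then scheduled on (flat_pgs[r], flat_indices[r]):
# rank 0 -> (pg_a, 0), rank 1 -> (pg_a, 1), rank 2 -> (pg_b, 0), rank 3 -> (pg_b, 1)
print(list(zip(flat_pgs, flat_indices)))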

tensorrt_llm/llmapi/llm.py

Lines changed: 4 additions & 21 deletions
@@ -126,17 +126,12 @@ def __init__(self,
                  dtype: str = "auto",
                  revision: Optional[str] = None,
                  tokenizer_revision: Optional[str] = None,
-                 placement_share: float = 1.0,
-                 placement_where: list[tuple[PlacementGroup, list[int]]] = None,
                  **kwargs: Any) -> None:

         self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
         self._orchestrator_type = kwargs.get("orchestrator_type", None)
         self._llm_id = None

-        self.placement_share = placement_share
-        self.placement_where = placement_where
-
         log_level = logger.level
         logger.set_level("info")  # force display the backend

@@ -814,14 +809,12 @@ def __init__(self,
                  dtype: str = "auto",
                  revision: Optional[str] = None,
                  tokenizer_revision: Optional[str] = None,
-                 placement_share: float = 1.0,
-                 placement_where: list[tuple[PlacementGroup, list[int]]] = None,
                  **kwargs: Any) -> None:
         # TODO: deprecate backend in LLM kwargs

         super().__init__(model, tokenizer, tokenizer_mode, skip_tokenizer_init,
                          trust_remote_code, tensor_parallel_size, dtype,
-                         revision, tokenizer_revision, placement_share, placement_where, **kwargs)
+                         revision, tokenizer_revision, **kwargs)

     @property
     def workspace(self) -> Path:
@@ -979,9 +972,7 @@ def _build_model(self):
                 num_postprocess_workers=self.args.num_postprocess_workers,
                 postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir,
             ),
-            is_llm_executor=True,
-            placement_share=self.placement_share,
-            placement_where=self.placement_where)
+            is_llm_executor=True)


 @append_docstring(TORCH_LLM_DOCSTRING)
@@ -1002,8 +993,6 @@ def __init__(self,
                  dtype: str = "auto",
                  revision: Optional[str] = None,
                  tokenizer_revision: Optional[str] = None,
-                 placement_share: float = 1.0,
-                 placement_where: list[tuple[PlacementGroup, list[int]]] = None,
                  **kwargs: Any) -> None:

         # TODO: deprecate backend in LLM kwargs
@@ -1022,8 +1011,6 @@ def __init__(self,
                          revision,
                          tokenizer_revision,
                          backend=backend,
-                         placement_share=placement_share,
-                         placement_where=placement_where,
                          **kwargs)

     @set_api_status("prototype")
@@ -1091,9 +1078,7 @@ def _build_model(self):
             is_llm_executor=True,
             hf_model_dir=self._hf_model_dir,
             tokenizer=self.tokenizer,
-            llm_args=self.args,
-            placement_share=self.placement_share,
-            placement_where=self.placement_where)
+            llm_args=self.args)

     def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
         """Validate that users don't pass TrtLlmArgs-specific arguments when using PyTorch backend.
@@ -1129,12 +1114,10 @@ def __init__(self,
                  dtype: str = "auto",
                  revision: Optional[str] = None,
                  tokenizer_revision: Optional[str] = None,
-                 placement_share: float = 1.0,
-                 placement_where: list[tuple[PlacementGroup, list[int]]] = None,
                  **kwargs: Any) -> None:
         super().__init__(model, tokenizer, tokenizer_mode, skip_tokenizer_init,
                          trust_remote_code, tensor_parallel_size, dtype,
-                         revision, tokenizer_revision, placement_share, placement_where, **kwargs)
+                         revision, tokenizer_revision, **kwargs)


 # sphinx will ignore the LLM's docstring if it is not explicitly set

tensorrt_llm/llmapi/llm_args.py

Lines changed: 66 additions & 2 deletions
@@ -8,8 +8,9 @@
 from dataclasses import dataclass
 from enum import Enum, EnumMeta
 from pathlib import Path
-from typing import (Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple,
-                    Type, TypeAlias, TypeVar, Union, get_args, get_origin)
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
+                    Set, Tuple, Type, TypeAlias, TypeVar, Union, get_args,
+                    get_origin)

 import torch
 import yaml
@@ -19,6 +20,11 @@
 from strenum import StrEnum
 from transformers import PreTrainedTokenizerBase

+try:
+    from ray.util.placement_group import PlacementGroup
+except ImportError:
+    PlacementGroup = None
+
 from tensorrt_llm.lora_helper import (LoraConfig,
                                       get_default_trtllm_modules_to_hf_modules)

@@ -2695,6 +2701,26 @@ class TorchLlmArgs(BaseLlmArgs):
         "Allows users to extend the functions of the RayGPUWorker class.",
         status="prototype")

+    # Ray placement group config. Namings TBD.
+    placement_groups: Optional[List[Any]] = Field(
+        default=None,
+        description="List of Ray placement groups, one per node. "
+        "Each element must be a ray.util.placement_group.PlacementGroup instance.",
+        exclude_from_json=True,
+        status="prototype")
+
+    placement_bundle_indices: Optional[List[List[int]]] = Field(
+        default=None,
+        description="List of bundle indices for each placement group. "
+        "Outer list corresponds to placement_groups, inner list contains bundle indices for that group. ",
+        status="prototype")
+
+    per_worker_gpu_share: Optional[float] = Field(
+        default=None,
+        description="GPU fraction per worker for colocation scenarios. "
+        "Example: 0.1 means 10 actors can share one GPU. Defaults to 1.0 (one actor per GPU).",
+        status="prototype")
+
     enable_sleep: bool = Field(
         default=False,
         description=
@@ -3000,6 +3026,44 @@ def validate_ray_worker_extension_cls(self) -> 'TorchLlmArgs':
             )
         return self

+    @model_validator(mode='after')
+    def validate_ray_placement_config(self) -> 'TorchLlmArgs':
+        has_pgs = self.placement_groups is not None
+        has_indices = self.placement_bundle_indices is not None
+
+        if (has_pgs or has_indices) and self.orchestrator_type != "ray":
+            raise ValueError(
+                "placement_groups is only supported with orchestrator_type='ray'"
+            )
+
+        if has_pgs != has_indices:
+            raise ValueError(
+                "placement_groups and placement_bundle_indices must be provided together"
+            )
+
+        if has_pgs:
+            if len(self.placement_groups) != len(self.placement_bundle_indices):
+                raise ValueError(
+                    f"placement_groups length ({len(self.placement_groups)}) must equal "
+                    f"placement_bundle_indices length ({len(self.placement_bundle_indices)})"
+                )
+
+        if self.per_worker_gpu_share is not None:
+            if not (0 < self.per_worker_gpu_share <= 1.0):
+                raise ValueError(
+                    f"per_worker_gpu_share must be between 0 and 1.0, "
+                    f"got {self.per_worker_gpu_share}")
+
+        if has_pgs:
+            if PlacementGroup is not None:
+                for i, pg in enumerate(self.placement_groups):
+                    if not isinstance(pg, PlacementGroup):
+                        raise TypeError(
+                            f"placement_groups[{i}] must be a Ray PlacementGroup, "
+                            f"got {type(pg).__name__}")
+
+        return self
+
     def get_executor_config(
             self,
             _hf_model_dir: Optional[Path] = None,
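A hedged usage sketch of the new prototype fields follows. It assumes the llmapi LLM constructor forwards these keyword arguments into TorchLlmArgs (as it does for other fields); the import path, model path, and bundle sizes are placeholders, not taken from this commit:

import ray
from ray.util.placement_group import placement_group

from tensorrt_llm import LLM  # assumed import path for the llmapi LLM

ray.init()

# One placement group for one node, with two single-GPU bundles (illustrative sizes).
pg = placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK")
ray.get(pg.ready())

llm = LLM(
    model="/path/to/model",             # placeholder path
    tensor_parallel_size=2,             # world size must equal the total bundle indices
    orchestrator_type="ray",            # required by validate_ray_placement_config
    placement_groups=[pg],              # one PlacementGroup per node
    placement_bundle_indices=[[0, 1]],  # bundle indices used within that group
    per_worker_gpu_share=0.5,           # fractional GPU per worker for colocation
)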
