Skip to content

Commit 23d5280

Browse files
authored
[TRTLLM-7843][feat] implement disagg cluster auto-scaling (NVIDIA#8215)
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 9b54b3b commit 23d5280

File tree

11 files changed

+682
-59
lines changed

11 files changed

+682
-59
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ nvtx
6767
matplotlib # FIXME: this is added to make nvtx happy
6868
meson
6969
ninja
70-
etcd3
70+
etcd3 @ git+https://github.com/kragniz/python-etcd3.git@e58a899579ba416449c4e225b61f039457c8072a
7171
blake3
7272
soundfile
7373
triton==3.3.1; platform_machine == "x86_64"

tensorrt_llm/commands/serve.py

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,9 @@
2121
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
2222
DynamicBatchConfig, KvCacheConfig,
2323
SchedulerConfig)
24-
from tensorrt_llm.llmapi.disagg_utils import (MetadataServerConfig, ServerRole,
24+
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
25+
MetadataServerConfig, ServerRole,
26+
extract_disagg_cluster_config,
2527
parse_disagg_config_file,
2628
parse_metadata_server_config_file)
2729
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
@@ -140,7 +142,8 @@ def launch_server(host: str,
140142
port: int,
141143
llm_args: dict,
142144
metadata_server_cfg: Optional[MetadataServerConfig] = None,
143-
server_role: Optional[ServerRole] = None):
145+
server_role: Optional[ServerRole] = None,
146+
disagg_cluster_config: Optional[DisaggClusterConfig] = None):
144147

145148
backend = llm_args["backend"]
146149
model = llm_args["model"]
@@ -161,7 +164,8 @@ def launch_server(host: str,
161164
server = OpenAIServer(llm=llm,
162165
model=model,
163166
server_role=server_role,
164-
metadata_server_cfg=metadata_server_cfg)
167+
metadata_server_cfg=metadata_server_cfg,
168+
disagg_cluster_config=disagg_cluster_config)
165169

166170
# Optionally disable GC (default: not disabled)
167171
if os.getenv("TRTLLM_SERVER_DISABLE_GC", "0") == "1":
@@ -313,6 +317,10 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
313317
help=
314318
"Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
315319
)
320+
@click.option("--disagg_cluster_uri",
321+
type=str,
322+
default=None,
323+
help="URI of the disaggregated cluster.")
316324
@click.option("--enable_chunked_prefill",
317325
is_flag=True,
318326
default=False,
@@ -327,7 +335,7 @@ def serve(
327335
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
328336
metadata_server_config_file: Optional[str], server_role: Optional[str],
329337
fail_fast_on_attention_window_too_large: bool,
330-
enable_chunked_prefill: bool):
338+
enable_chunked_prefill: bool, disagg_cluster_uri: Optional[str]):
331339
"""Running an OpenAI API compatible server
332340
333341
MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -364,14 +372,27 @@ def serve(
364372
metadata_server_cfg = parse_metadata_server_config_file(
365373
metadata_server_config_file)
366374

367-
if metadata_server_cfg is not None:
368-
assert server_role is not None, "server_role is required when metadata_server_cfg is provided"
375+
# Specify disagg_cluster_config in config file or through command line "--disagg_cluster_uri",
376+
# but disagg_cluster_uri takes precedence over cluster uri in config file
377+
disagg_cluster_config = llm_args.pop("disagg_cluster", None)
378+
if disagg_cluster_config:
379+
disagg_cluster_config = extract_disagg_cluster_config(
380+
disagg_cluster_config, disagg_cluster_uri)
381+
elif disagg_cluster_uri:
382+
disagg_cluster_config = DisaggClusterConfig(
383+
cluster_uri=disagg_cluster_uri)
384+
385+
if metadata_server_cfg is not None or disagg_cluster_config is not None:
386+
assert (
387+
server_role is not None
388+
), "server_role is required when metadata_server_cfg or disagg_cluster_config is provided"
369389
try:
370390
server_role = ServerRole[server_role.upper()]
371391
except ValueError:
372392
raise ValueError(f"Invalid server role: {server_role}. " \
373393
f"Must be one of: {', '.join([role.name for role in ServerRole])}")
374-
launch_server(host, port, llm_args, metadata_server_cfg, server_role)
394+
launch_server(host, port, llm_args, metadata_server_cfg, server_role,
395+
disagg_cluster_config)
375396

376397

377398
@click.command("mm_embedding_serve")

tensorrt_llm/llmapi/disagg_utils.py

Lines changed: 40 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import logging
22
from dataclasses import dataclass, field
33
from enum import IntEnum
4-
from typing import Any, List, Literal, Optional, Tuple
4+
from typing import Any, Dict, List, Literal, Optional, Tuple
55

66
import yaml
77
from mpi4py.MPI import COMM_WORLD, Comm
@@ -68,6 +68,7 @@ class DisaggServerConfig():
6868
conditional_disagg_config: Optional[ConditionalDisaggConfig] = None
6969
max_retries: int = 1
7070
perf_metrics_max_requests: int = 0
71+
disagg_cluster_config: Optional[DisaggClusterConfig] = None
7172

7273

7374
@dataclass
@@ -111,6 +112,7 @@ def extract_disagg_cfg(hostname: str = 'localhost',
111112
context_servers: Optional[dict] = None,
112113
generation_servers: Optional[dict] = None,
113114
conditional_disagg_config: Optional[dict] = None,
115+
disagg_cluster: Optional[dict] = None,
114116
**kwargs: Any) -> DisaggServerConfig:
115117
context_servers = context_servers or {}
116118
generation_servers = generation_servers or {}
@@ -131,23 +133,27 @@ def extract_disagg_cfg(hostname: str = 'localhost',
131133
# Inherit the value from the top-level
132134
servers[key] = value
133135

136+
server_configs = []
137+
disagg_cluster_config = None
134138
ctx_router_config = extract_router_config(context_servers)
135139
gen_router_config = extract_router_config(generation_servers)
136-
137-
server_configs = extract_ctx_gen_cfgs(
138-
type="ctx", **context_servers) + extract_ctx_gen_cfgs(
139-
type="gen", **generation_servers)
140-
141140
ctx_router_config.server_role = ServerRole.CONTEXT
142141
gen_router_config.server_role = ServerRole.GENERATION
142+
if disagg_cluster:
143+
disagg_cluster_config = extract_disagg_cluster_config(disagg_cluster)
144+
else:
145+
server_configs = extract_ctx_gen_cfgs(
146+
type="ctx", **context_servers) + extract_ctx_gen_cfgs(
147+
type="gen", **generation_servers)
143148

144149
conditional_disagg_config = ConditionalDisaggConfig(
145150
**conditional_disagg_config) if conditional_disagg_config else None
146151

147152
config = DisaggServerConfig(server_configs, hostname, port,
148153
ctx_router_config, gen_router_config,
149154
conditional_disagg_config, max_retries,
150-
perf_metrics_max_requests)
155+
perf_metrics_max_requests,
156+
disagg_cluster_config)
151157

152158
return config
153159

@@ -235,6 +241,33 @@ def get_server_configs_dict(
235241
return num_workers, server_dict
236242

237243

244+
def extract_disagg_cluster_config(
245+
cluster_config_dict: Dict[str, Any],
246+
cluster_uri: Optional[str] = None) -> DisaggClusterConfig:
247+
"""
248+
Build the DisaggClusterConfig from the cluster_config_dict.
249+
Use the default value of DisaggClusterConfig and MinimalInstances if the corresponding fields are not provided.
250+
If cluster_uri is provided, it will override the cluster_uri in the cluster_config_dict.
251+
"""
252+
253+
def update_dataclass(obj, data_dict: Dict[str, Any]):
254+
for key, value in data_dict.items():
255+
if key not in obj.__dataclass_fields__:
256+
raise KeyError(
257+
f"Key {key} not found in {obj.__class__.__name__}")
258+
if value is not None:
259+
setattr(obj, key, value)
260+
return obj
261+
262+
cluster_config_dict["minimal_instances"] = update_dataclass(
263+
MinimalInstances(), cluster_config_dict.get("minimal_instances", {}))
264+
cluster_config = update_dataclass(
265+
DisaggClusterConfig(cluster_uri or cluster_config_dict["cluster_uri"]),
266+
cluster_config_dict,
267+
)
268+
return cluster_config
269+
270+
238271
def split_world_comm(
239272
server_configs: List[CtxGenServerConfig]) -> Tuple[bool, int, Comm]:
240273

tensorrt_llm/serve/cluster_storage.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -281,9 +281,10 @@ async def _check_expired(self):
281281
self._storage.pop(k)
282282
for k, v in kv_to_delete.items():
283283
await self._notify_watch_event(k, v, WatchEventType.DELETE)
284-
logger.debug(
285-
f"Checked expired, {before_len} -> {len(self._storage)}, keys to delete: {kv_to_delete.keys()}"
286-
)
284+
if len(kv_to_delete) > 0:
285+
logger.debug(
286+
f"Checked expired, {before_len} -> {len(self._storage)}, keys to delete: {kv_to_delete.keys()}"
287+
)
287288
except Exception as e:
288289
logger.error(f"Error checking expired: {e}")
289290

@@ -298,9 +299,12 @@ def __init__(self, cluster_uri, cluster_name):
298299
self._cluster_name = cluster_name
299300

300301
def __del__(self):
301-
if asyncio.get_event_loop():
302-
asyncio.run_coroutine_threadsafe(self._session.close(),
303-
asyncio.get_event_loop())
302+
try:
303+
if asyncio.get_event_loop():
304+
asyncio.run_coroutine_threadsafe(self._session.close(),
305+
asyncio.get_event_loop())
306+
except RuntimeError:
307+
pass
304308

305309
def _url_for(self, endpoint: str) -> str:
306310
return f"{self._cluster_uri}/{endpoint}"

tensorrt_llm/serve/disagg_auto_scaling.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,13 @@ def __init__(self, config: DisaggClusterConfig, storage: ClusterStorage):
4646
self._watch_handle = None
4747

4848
def __del__(self):
49-
if asyncio.get_event_loop():
50-
asyncio.run_coroutine_threadsafe(self.stop(),
51-
asyncio.get_event_loop())
49+
try:
50+
if asyncio.get_event_loop():
51+
asyncio.run_coroutine_threadsafe(self.stop(),
52+
asyncio.get_event_loop())
53+
except RuntimeError:
54+
# the event loop may not be running when the cluster manager is destroyed
55+
pass
5256

5357
async def start(self) -> None:
5458
await self._cluster_storage.start()
@@ -208,9 +212,13 @@ def __init__(self, role: ServerRole, host: str, port: int,
208212
self._worker_id = f"{role.name}-{host}:{port}-{int(time.time()*1000)}-{os.getpid()}-{random.randint(0, 1000):03}"
209213

210214
def __del__(self):
211-
if asyncio.get_event_loop():
212-
asyncio.run_coroutine_threadsafe(self.deregister_worker(),
213-
asyncio.get_event_loop())
215+
try:
216+
if asyncio.get_event_loop():
217+
asyncio.run_coroutine_threadsafe(self.deregister_worker(),
218+
asyncio.get_event_loop())
219+
except RuntimeError:
220+
# the event loop may not be running when the worker is destroyed
221+
pass
214222

215223
@property
216224
def worker_id(self) -> str:

0 commit comments

Comments (0)