Skip to content

Commit bd13957

Browse files
authored
[TRTLLM-9181][feat] improve disagg-server prometheus metrics; synchronize workers' clocks when workers are dynamic (#9726)
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 609d1d0 commit bd13957

25 files changed

+987
-324
lines changed

tensorrt_llm/_utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,10 +473,20 @@ def dim_resolve_negative(dim, ndim):
473473
return tuple(pos)
474474

475475

476-
def get_free_port():
477-
with socket.socket() as sock:
478-
sock.bind(("", 0))
479-
return sock.getsockname()[1]
476+
def get_free_port() -> int:
    """Return one TCP port that was free on this host at probe time."""
    return get_free_ports(1)[0]


def get_free_ports(num: int = 1) -> List[int]:
    """Return *num* distinct TCP ports that were free at probe time.

    All sockets are kept open until every port has been discovered so the
    OS cannot hand the same ephemeral port out twice within one call.
    The usual TOCTOU caveat applies: a returned port may be taken by
    another process before the caller binds it.

    Args:
        num: how many ports to reserve-and-probe (default 1).

    Returns:
        A list of ``num`` port numbers.
    """
    sockets = []
    try:
        for _ in range(num):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sockets.append(s)
            s.bind(('', 0))  # port 0 -> kernel picks a free ephemeral port
        return [s.getsockname()[1] for s in sockets]
    finally:
        # Close unconditionally so a failed bind() midway does not leak
        # the already-created sockets (the original closed only on success).
        for s in sockets:
            s.close()
480490

481491

482492
# mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here

tensorrt_llm/serve/disagg_auto_scaling.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import random
5+
import socket
56
import time
67
from dataclasses import asdict, dataclass
78
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
@@ -29,6 +30,18 @@ def get_worker_key(name: str, role: ServerRole, worker_id: str = "") -> str:
2930
return f"{get_worker_key_prefix(name)}/{worker_id}"
3031

3132

33+
def get_host_from_uri(uri: str) -> str:
    """Extract the host part from a URI such as ``scheme://host:port[/path]``.

    Uses :func:`urllib.parse.urlparse` instead of naive string splitting,
    which mishandled port-less URIs with a path (``http://host/path`` ->
    ``"host/path"``) and raised IndexError on scheme-less input.

    Raises:
        ValueError: if no host can be parsed out of *uri*.
    """
    from urllib.parse import urlparse  # stdlib; local import keeps edit self-contained

    hostname = urlparse(uri).hostname
    if hostname is None:
        raise ValueError(f"Cannot parse host from URI: {uri}")
    return hostname
35+
36+
37+
# Discover which local IP address would be used to reach *remote_host*.
# When no host is given, Google's public DNS server "8.8.8.8" is used.
def get_local_ip(remote_host: str = "8.8.8.8") -> str:
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # connect() on a UDP socket transmits nothing; it merely selects
        # the route, so getsockname() reveals the outbound interface IP.
        sock.connect((remote_host, 80))
        return sock.getsockname()[0]
    finally:
        sock.close()
43+
44+
3245
class DisaggClusterManager:
3346
"""
3447
The cluster manager is responsible for managing the workers in the cluster.
@@ -238,18 +251,25 @@ class DisaggClusterWorker:
238251
It will send heartbeat to the cluster storage every heartbeat_interval_sec seconds.
239252
If the worker heartbeat fails, it will re-register itself.
240253
"""
254+
# Host strings that only resolve to the local machine; a worker registered
# under one of these would be unreachable from other nodes in the cluster.
LOCALHOST_IPS = ["localhost", "127.0.0.1", "0.0.0.0", "::1",
                 "::"]  # nosec B104

def __init__(self, role: ServerRole, host: str, port: int,
             config: DisaggClusterConfig, storage: ClusterStorage):
    """Initialize a cluster worker and compute its registration identity.

    Args:
        role: the server role (ctx/gen) this worker advertises.
        host: the bind host; may be a localhost alias (see below).
        port: the worker's serving port.
        config: cluster configuration (provides ``cluster_uri``).
        storage: cluster storage backend used for registration/heartbeat.
    """
    self._role = role
    self._port = port
    self._config = config
    self._cluster_storage = storage
    self._stop = False
    self._heartbeat_task = None
    self._last_heartbeat = 0
    register_host = host
    # If bound to a localhost alias while the cluster manager lives on a
    # different machine, registering "localhost" would be useless to peers:
    # substitute the local IP routable toward the cluster host instead.
    disagg_host = get_host_from_uri(self._config.cluster_uri)
    if host in self.LOCALHOST_IPS and disagg_host not in self.LOCALHOST_IPS:
        register_host = get_local_ip(disagg_host)
    self._host = register_host
    # Worker id embeds role, address, wall-clock millis, pid and a random
    # suffix so concurrent workers on one host stay distinguishable.
    self._worker_id = f"{role.name}-{register_host}:{port}-{int(time.time()*1000)}-{os.getpid()}-{random.randint(0, 1000):03}"
253273

254274
def __del__(self):
255275
try:

tensorrt_llm/serve/openai_client.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ async def _post_with_retry(
183183
yield response_dict
184184
# finish the request after the successful response
185185
await self._finish_request(request)
186+
self._metrics_collector.complete_latency_seconds.observe(
187+
get_steady_clock_now_in_seconds() - start_time
188+
)
186189
break # break and skip retries if the whole response is processed without exception
187190
except (aiohttp.ClientError, OSError) as e:
188191
if lines_yielded > 0:
@@ -227,25 +230,24 @@ async def _response_generator(
227230
i = 0
228231
async for line in http_response.content.iter_any():
229232
now_time = get_steady_clock_now_in_seconds()
230-
if i == 0:
231-
if hooks:
232-
hooks.on_first_token(server, request)
233-
self._metrics_collector.first_token_latency_seconds.observe(
234-
now_time - last_token_time
235-
)
236-
else:
237-
self._metrics_collector.per_token_latency_seconds.observe(
238-
now_time - last_token_time
239-
)
240-
i += 1
241233
if line:
234+
if i == 0:
235+
if hooks:
236+
hooks.on_first_token(server, request)
237+
self._metrics_collector.first_token_latency_seconds.observe(
238+
now_time - last_token_time
239+
)
240+
else:
241+
self._metrics_collector.per_token_latency_seconds.observe(
242+
now_time - last_token_time
243+
)
244+
i += 1
242245
yield line
243246
await asyncio.sleep(0)
244247
last_token_time = now_time
245248

246249
if hooks:
247250
hooks.on_resp_done(server, request, None)
248-
self._metrics_collector.completed_requests.inc()
249251
self._metrics_collector.complete_latency_seconds.observe(
250252
get_steady_clock_now_in_seconds() - start_time
251253
)
@@ -262,6 +264,7 @@ async def _response_generator(
262264
await self._finish_request(request)
263265

264266
async def _finish_request(self, request: UCompletionRequest) -> None:
267+
self._metrics_collector.completed_requests.inc()
265268
await self._router.finish_request(request)
266269

267270
async def collect_metrics(self) -> Dict[str, Any]:

tensorrt_llm/serve/openai_disagg_server.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,12 @@ def __init__(self, raw_req: Request, perf_metrics_collector: DisaggPerfMetricsCo
5757
self.raw_req = raw_req
5858
self.ctx_server = ""
5959
self.gen_server = ""
60+
self.request_arrival_time = raw_req.state.server_arrival_time
6061
self.server_first_token_time = 0
6162
self.perf_metrics_collector = perf_metrics_collector
6263

6364
def on_req_begin(self, request: UCompletionRequest):
    # Observe queue latency: steady-clock time elapsed between the request
    # arriving at the server (recorded in __init__ from
    # raw_req.state.server_arrival_time) and processing actually starting.
    self.perf_metrics_collector.queue_latency_seconds.observe(get_steady_clock_now_in_seconds() - self.request_arrival_time)
6566

6667
def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse):
6768
self.ctx_server = ctx_server
@@ -93,8 +94,8 @@ def __init__(self,
9394
self._metrics_interval_secs = metrics_interval_secs
9495

9596
self._ctx_servers, self._gen_servers = get_ctx_gen_server_addrs(config.server_configs)
96-
self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
97-
self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
97+
self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock)
98+
self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock)
9899
self._metadata_server = create_metadata_server(metadata_server_cfg)
99100
self._perf_metrics_collector = DisaggPerfMetricsCollector(config.perf_metrics_max_requests)
100101

@@ -122,8 +123,10 @@ def __init__(self,
122123

123124
@asynccontextmanager
124125
async def lifespan(app) -> None:
126+
# Prepare servers (sync server clock) when static ctx/gen server list is used
127+
await self._ctx_router.prepare_servers()
128+
await self._gen_router.prepare_servers()
125129
await self._service.setup()
126-
await self._set_steady_clock_offsets()
127130
yield
128131
await self._service.teardown()
129132

@@ -133,6 +136,7 @@ async def lifespan(app) -> None:
133136

134137
@self.app.exception_handler(RequestValidationError)
135138
async def validation_exception_handler(_, exc):
139+
self._perf_metrics_collector.validation_exceptions.inc()
136140
return JSONResponse(status_code=400, content={"error": str(exc)})
137141

138142
self.register_routes()
@@ -158,8 +162,14 @@ def register_routes(self):
158162
def _wrap_entry_point(self, entry_point: Callable) -> Callable:
159163
async def wrapper(req: UCompletionRequest, raw_req: Request) -> Response:
160164
try:
165+
self._perf_metrics_collector.total_requests.inc()
166+
if req.stream:
167+
self._perf_metrics_collector.stream_requests.inc()
168+
else:
169+
self._perf_metrics_collector.nonstream_requests.inc()
161170
hooks = RawRequestResponseHooks(raw_req, self._perf_metrics_collector)
162171
response_or_generator = await entry_point(req, hooks)
172+
self._perf_metrics_collector.total_responses.inc()
163173
if req.stream:
164174
return StreamingResponse(content=response_or_generator, media_type="text/event-stream")
165175
else:
@@ -173,9 +183,11 @@ def _handle_exception(self, exception):
173183
logger.error("CppExecutorError: ", traceback.format_exc())
174184
signal.raise_signal(signal.SIGINT)
175185
elif isinstance(exception, HTTPException):
186+
self._perf_metrics_collector.http_exceptions.inc()
176187
logger.error(f"HTTPException {exception.status_code} {exception.detail}: ", traceback.format_exc())
177188
raise exception
178189
else:
190+
self._perf_metrics_collector.internal_errors.inc()
179191
logger.error("Internal server error: ", traceback.format_exc())
180192
raise HTTPException(status_code=500, detail=f"Internal server error {str(exception)}")
181193

@@ -199,13 +211,12 @@ async def __call__(self, host: str, port: int, sockets: list[socket.socket] | No
199211
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
200212
await uvicorn.Server(config).serve(sockets=sockets)
201213

202-
# TODO: rework this for service discovery, now it's only for static server list
203-
async def _set_steady_clock_offsets(self):
204-
STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset"
214+
async def _sync_server_clock(self, server: str):
215+
""" Sync the ctx/gen server's steady clock with the disagg-server's steady clock (in case NTP service is not running). """
205216
async def query_steady_clock_offset(session: aiohttp.ClientSession, server_url: str) -> tuple[Optional[float], Optional[float]]:
206217
try:
207218
originate_ts = get_steady_clock_now_in_seconds()
208-
async with session.get(server_url + STEADY_CLOCK_OFFSET_ENDPOINT) as response:
219+
async with session.get(server_url) as response:
209220
destination_ts = get_steady_clock_now_in_seconds()
210221
if response.status == 200:
211222
response_content = await response.json()
@@ -222,12 +233,11 @@ async def query_steady_clock_offset(session: aiohttp.ClientSession, server_url:
222233

223234
async def set_steady_clock_offset(session: aiohttp.ClientSession, server_url: str, offset: float) -> None:
224235
payload = {"offset": offset}
225-
async with session.post(server_url + STEADY_CLOCK_OFFSET_ENDPOINT, json=payload) as response:
236+
async with session.post(server_url, json=payload) as response:
226237
if response.status != 200:
227238
logger.warning(f"Cannot set disagg server steady clock offset for server {server_url}, the perf metrics timestamps could be mis-aligned")
228239

229240
async def align_steady_clock_offset(session: aiohttp.ClientSession, server_url: str) -> None:
230-
server_url = f"http://{server_url}" if not server_url.startswith("http://") else server_url
231241
delay, offset = await query_steady_clock_offset(session, server_url)
232242
if delay is None or offset is None:
233243
logger.warning(f"Unable to measure steady clock offset for {server_url}; skipping adjustment")
@@ -236,7 +246,13 @@ async def align_steady_clock_offset(session: aiohttp.ClientSession, server_url:
236246
# Negate the offset so that worker servers can adjust their steady clock by adding the new offset
237247
await set_steady_clock_offset(session, server_url, -offset)
238248

239-
async with aiohttp.ClientSession(
240-
connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
241-
timeout=aiohttp.ClientTimeout(total=self._req_timeout_secs)) as session:
242-
await asyncio.gather(*[align_steady_clock_offset(session, server_url) for server_url in self._ctx_servers + self._gen_servers])
249+
server_scheme = "http://" if not server.startswith("http://") else ""
250+
server_url = f"{server_scheme}{server}/steady_clock_offset"
251+
252+
try:
253+
async with aiohttp.ClientSession(
254+
connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
255+
timeout=aiohttp.ClientTimeout(total=self._req_timeout_secs)) as session:
256+
await align_steady_clock_offset(session, server_url)
257+
except (aiohttp.ClientError, OSError) as e:
258+
logger.warning(f"Unable to align steady clock offset for {server_url}: {e}; skipping adjustment")

tensorrt_llm/serve/perf_metrics.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import asyncio
1616
from collections import defaultdict, deque
1717
from dataclasses import dataclass
18-
from typing import Any, Dict, List, Literal, Optional, Union
18+
from typing import Any, Dict, List, Literal, Optional
1919

2020
from tensorrt_llm.llmapi.disagg_utils import ServerRole
2121

@@ -64,7 +64,7 @@ class MetricsDefinition:
6464
buckets: Optional[List[float]] = None
6565

6666

67-
METRICS_DEFINITIONS = [
67+
CLIENT_METRICS_DEFINITIONS = [
6868
MetricsDefinition("total_requests", "Total number of requests", "counter"),
6969
MetricsDefinition("error_requests", "Total number of error requests", "counter"),
7070
MetricsDefinition("retry_requests", "Total number of retry requests", "counter"),
@@ -96,23 +96,29 @@ class MetricsDefinition:
9696
}
9797

9898

99+
def instance_metric(definition: MetricsDefinition, role: Optional[ServerRole] = None):
    """Instantiate a prometheus metric for *definition*.

    When *role* maps to a known client type, its prefix is prepended to the
    metric name; otherwise the definition's name is used as-is.

    Raises:
        ValueError: if the definition's type is neither counter nor histogram.
    """
    # import lazily to avoid breaking `set_prometheus_multiproc_dir`
    from prometheus_client import Counter, Histogram

    if role in ROLE_TO_CLIENT_TYPE:
        name = f"{ROLE_TO_CLIENT_TYPE[role]}_{definition.name}"
    else:
        name = definition.name

    if definition.type == "counter":
        return Counter(name, definition.description)
    if definition.type == "histogram":
        return Histogram(name, definition.description, buckets=definition.buckets)
    raise ValueError(f"Invalid metric type: {definition.type}")
114+
115+
99116
class ClientMetricsCollector:
100117
def __init__(self, role: ServerRole):
101118
self._role = role
102-
# import lazily to avoid breaking `set_prometheus_multiproc_dir`
103-
from prometheus_client import Counter, Histogram
104-
105-
def instance_metric(definition: MetricsDefinition) -> Union[Counter | Histogram]:
106-
name = f"{ROLE_TO_CLIENT_TYPE[role]}_{definition.name}"
107-
if definition.type == "counter":
108-
return Counter(name, definition.description)
109-
elif definition.type == "histogram":
110-
return Histogram(name, definition.description, buckets=definition.buckets)
111-
else:
112-
raise ValueError(f"Invalid metric type: {definition.type}")
113-
114119
self._metrics = {
115-
definition.name: instance_metric(definition) for definition in METRICS_DEFINITIONS
120+
definition.name: instance_metric(definition, role)
121+
for definition in CLIENT_METRICS_DEFINITIONS
116122
}
117123

118124
def __getattr__(
@@ -121,17 +127,41 @@ def __getattr__(
121127
return self._metrics[key]
122128

123129

130+
SERVER_METRICS_DEFINITIONS = [
131+
MetricsDefinition("total_requests", "Total number of requests", "counter"),
132+
MetricsDefinition("stream_requests", "Total number of stream requests", "counter"),
133+
MetricsDefinition("nonstream_requests", "Total number of non-stream requests", "counter"),
134+
MetricsDefinition("validation_exceptions", "Total number of validation exceptions", "counter"),
135+
MetricsDefinition("http_exceptions", "Total number of HTTP exceptions", "counter"),
136+
MetricsDefinition("internal_errors", "Total number of internal errors", "counter"),
137+
MetricsDefinition("total_responses", "Total number of responses", "counter"),
138+
MetricsDefinition(
139+
"queue_latency_seconds",
140+
"Histogram of latency from request arrival to being processed in seconds",
141+
"histogram",
142+
SHORT_TIME_BUCKETS,
143+
),
144+
]
145+
146+
124147
class DisaggPerfMetricsCollector:
125148
def __init__(self, max_requests: int):
126149
self._max_requests = max_requests
127150
self._request_meteics = deque(maxlen=max_requests)
128151
self._server_metrics = defaultdict(dict)
129152
self._lock = asyncio.Lock()
130153
self._clients = []
154+
self._metrics = {
155+
definition.name: instance_metric(definition)
156+
for definition in SERVER_METRICS_DEFINITIONS
157+
}
131158

132159
def add_client(self, client):
133160
self._clients.append(client)
134161

162+
def __getattr__(self, key: str):
    """Expose each server metric (e.g. ``self.total_requests``) as an attribute.

    Raises AttributeError — not KeyError — for unknown names so that
    ``hasattr``, ``copy`` and ``pickle`` behave correctly, and guards
    against infinite recursion if accessed before ``_metrics`` exists.
    """
    metrics = self.__dict__.get("_metrics")
    if metrics is not None and key in metrics:
        return metrics[key]
    raise AttributeError(key)
164+
135165
async def add_per_request_metrics(
136166
self,
137167
ctx_server: str,

0 commit comments

Comments
 (0)