Skip to content

Commit ea6cd76

Browse files
authored
[None][refactor] simplify get_stats and get_kvcache_events with rpc (#9980)
Signed-off-by: Yan Chunwei <[email protected]> Signed-off-by: Superjomn <[email protected]>
1 parent c87f1a6 commit ea6cd76

File tree

10 files changed

+400
-414
lines changed

10 files changed

+400
-414
lines changed

tensorrt_llm/executor/proxy.py

Lines changed: 125 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import atexit
22
import concurrent.futures
3+
import json
4+
import os
35
import threading
4-
import time
56
import weakref
6-
from typing import Dict, Optional, Union
7+
from typing import Dict, List, Optional
78

89
import torch
910
import zmq
@@ -22,9 +23,11 @@
2223
from .postproc_worker import PostprocWorker, PostprocWorkerConfig
2324
from .request import CancellingRequest, GenerationRequest
2425
from .result import GenerationResult, IterationResult
25-
from .utils import (ErrorResponse, IntraProcessQueue, WorkerCommIpcAddrs,
26-
create_mpi_comm_session, get_spawn_proxy_process_env,
27-
is_llm_response, print_alive_threads)
26+
from .rpc import RPCClient
27+
from .rpc.rpc_common import get_unique_ipc_addr
28+
from .utils import (ErrorResponse, WorkerCommIpcAddrs, create_mpi_comm_session,
29+
get_spawn_proxy_process_env, is_llm_response,
30+
print_alive_threads)
2831
from .worker import GenerationExecutorWorker, worker_main
2932

3033
__all__ = [
@@ -89,19 +92,27 @@ def __init__(
8992
"llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get(
9093
"llm_args", None) is not None else None
9194

95+
# Generate RPC address and key for stats RPC
96+
self.rpc_addr = get_unique_ipc_addr()
97+
self.hmac_key = os.urandom(32)
98+
9299
worker_kwargs = dict(**worker_kwargs,
93100
worker_queues=self._setup_queues(),
94101
postproc_worker_config=postproc_worker_config,
95-
is_llm_executor=False)
102+
is_llm_executor=False,
103+
rpc_addr=self.rpc_addr,
104+
hmac_key=self.hmac_key)
96105

97106
if "log_level" not in worker_kwargs:
98107
worker_kwargs["log_level"] = logger.level
99108

100109
self.dispatch_result_thread: Optional[ManagedThread] = None
101-
self.dispatch_stats_thread: Optional[ManagedThread] = None
102-
self.dispatch_kv_cache_events_thread: Optional[ManagedThread] = None
110+
self.rpc_client: Optional[RPCClient] = None
103111
self._start_executor_workers(worker_kwargs)
104112

113+
# Create RPC client after workers are started (worker starts RPC server)
114+
self.rpc_client = RPCClient(self.rpc_addr, hmac_key=self.hmac_key)
115+
105116
# MPI registers its joiner using threading._register_atexit if possible.
106117
# These functions run before atexit.register, so to avoid deadlock,
107118
# we have to notify workers to exit before MPI starts to wait them.
@@ -128,19 +139,11 @@ def _setup_queues(self) -> WorkerCommIpcAddrs:
128139
socket_type=zmq.PULL
129140
if self.enable_postprocess_parallel else zmq.PAIR,
130141
name="proxy_result_queue")
131-
self.mp_stats_queue = FusedIpcQueue(is_server=True,
132-
fuse_message=False,
133-
name="proxy_stats_queue")
134-
self.kv_cache_events_queue = FusedIpcQueue(
135-
is_server=True,
136-
fuse_message=False,
137-
name="proxy_kv_cache_events_queue")
142+
# Stats and KV events are now fetched via RPC, not IPC queues.
138143
return WorkerCommIpcAddrs(
139144
request_queue_addr=self.request_queue.address,
140145
worker_init_status_queue_addr=self.worker_init_status_queue.address,
141146
result_queue_addr=self.result_queue.address,
142-
stats_queue_addr=self.mp_stats_queue.address,
143-
kv_cache_events_queue_addr=self.kv_cache_events_queue.address,
144147
)
145148

146149
def abort_request(self, request_id: int) -> None:
@@ -204,71 +207,8 @@ def process_res(res):
204207

205208
return True # success
206209

207-
def _iteration_result_task(self,
208-
queue: Union[FusedIpcQueue, IntraProcessQueue],
209-
result_singleton: IterationResult,
210-
urgent: bool = False) -> bool:
211-
if not urgent:
212-
time.sleep(0.2)
213-
214-
try:
215-
data = queue.get()
216-
except:
217-
logger.debug(
218-
"proxy.py: Error in _iteration_result_task: queue.get()")
219-
return False
220-
221-
if data is None:
222-
logger.debug("proxy.py: _iteration_result_task: data is None")
223-
return False # shutdown the thread
224-
225-
data = data if isinstance(data, list) else [data]
226-
queue = result_singleton.queue
227-
async_queues = []
228-
229-
while queue.full():
230-
queue.get()
231-
232-
try:
233-
for d in data:
234-
if d is None:
235-
logger.debug("proxy.py: _iteration_result_task: d is None")
236-
return False
237-
238-
if isinstance(queue, _SyncQueue):
239-
queue.put_nowait(d)
240-
async_queues.append(queue)
241-
else:
242-
queue.put(d)
243-
244-
if async_queues:
245-
_SyncQueue.notify_many(queue.loop, async_queues)
246-
247-
except AsyncQueue.EventLoopShutdownError:
248-
# This happens in the last loop while the generate workflow is
249-
# stopped, or when get_stats() or aget_stats() are not called by users
250-
# and therefore event loop can already be closed.
251-
logger.debug("proxy.py: EventLoopShutdownError")
252-
except Exception as e:
253-
logger.debug(f"proxy.py: Error in _iteration_result_task: {e}")
254-
raise e
255-
256-
return True # success
257-
258-
def dispatch_stats_task(self) -> bool:
259-
if not self._iter_stats_result:
260-
# This can happen temporarily because the WAR in tensorrt_llm/bench/benchmark/throughput.py
261-
# is not synchronized with self.dispatch_stats_thread.
262-
logger.debug(
263-
f"Skipping stats dispatch while self._iter_stats_result=None")
264-
return True # Intended behavior, not an error
265-
return self._iteration_result_task(self.mp_stats_queue,
266-
self._iter_stats_result)
267-
268-
def dispatch_kv_cache_events_task(self) -> bool:
269-
return self._iteration_result_task(self.kv_cache_events_queue,
270-
self._iter_kv_events_result,
271-
urgent=True)
210+
# NOTE: _iteration_result_task, dispatch_stats_task, and dispatch_kv_cache_events_task
211+
# have been removed as stats and kv_events are now fetched via RPC directly.
272212

273213
def _start_dispatch_threads(self):
274214
if self.dispatch_result_thread is None:
@@ -277,25 +217,9 @@ def _start_dispatch_threads(self):
277217
weakref.WeakMethod(self.dispatch_result_task),
278218
error_queue=self._error_queue,
279219
name="proxy_dispatch_result_thread")
280-
self.dispatch_stats_thread = ManagedThread(
281-
weakref.WeakMethod(self.dispatch_stats_task),
282-
error_queue=self._error_queue,
283-
name="proxy_dispatch_stats_thread")
284-
self.dispatch_kv_cache_events_thread = ManagedThread(
285-
weakref.WeakMethod(self.dispatch_kv_cache_events_task),
286-
error_queue=self._error_queue,
287-
name="proxy_dispatch_kv_cache_events_thread")
288220

289221
self.dispatch_result_thread.start()
290222

291-
# Only collect stats when submission
292-
# is via LLM API
293-
if self._iter_stats_result:
294-
self.dispatch_stats_thread.start()
295-
296-
if self._iter_kv_events_result:
297-
self.dispatch_kv_cache_events_thread.start()
298-
299223
self._handle_background_error()
300224

301225
def _start_executor_workers(self, worker_kwargs):
@@ -387,23 +311,18 @@ def shutdown(self):
387311
):
388312
self.dispatch_result_thread.stop()
389313
self.dispatch_result_thread.join()
390-
if self.dispatch_stats_thread is not None and self.dispatch_stats_thread.is_alive(
391-
):
392-
self.dispatch_stats_thread.stop()
393-
self.dispatch_stats_thread.join()
394-
if self.dispatch_kv_cache_events_thread is not None and self.dispatch_kv_cache_events_thread.is_alive(
395-
):
396-
self.dispatch_kv_cache_events_thread.stop()
397-
self.dispatch_kv_cache_events_thread.join()
398314

399315
# step3: finish all remaining work
400316

317+
# close the RPC client
318+
if self.rpc_client is not None:
319+
self.rpc_client.close()
320+
self.rpc_client = None
321+
401322
# close all the sockets
402323
self.request_queue.close()
403324
self.worker_init_status_queue.close()
404325
self.result_queue.close()
405-
self.mp_stats_queue.close()
406-
self.kv_cache_events_queue.close()
407326

408327
self.workers_started = False
409328
self.mpi_session.shutdown()
@@ -441,6 +360,104 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
441360

442361
return result
443362

363+
def get_stats(self, timeout: float) -> List[dict]:
364+
"""Get iteration statistics from the runtime via RPC.
365+
366+
Args:
367+
timeout (float): Max wait time in seconds for the RPC call.
368+
369+
Returns:
370+
List[dict]: A list of runtime stats as dict.
371+
"""
372+
if self.rpc_client is None:
373+
logger.warning("RPC client not initialized, cannot get stats")
374+
return []
375+
376+
stats = self.rpc_client.fetch_stats_wait_async(timeout=timeout).remote()
377+
return [json.loads(s) if isinstance(s, str) else s for s in stats]
378+
379+
def aget_stats(self, timeout: float) -> IterationResult:
380+
"""Get iteration statistics from the runtime via RPC (async).
381+
382+
Args:
383+
timeout (float): Max wait time in seconds for the RPC call.
384+
385+
Returns:
386+
IterationResult: An async iterable object containing runtime stats.
387+
"""
388+
# Initialize iteration result if needed
389+
self._maybe_initialize_iteration_results()
390+
391+
if self._iter_stats_result is None:
392+
logger.warning("Iteration statistics are not available yet.")
393+
from .executor import empty_async_iterable
394+
return empty_async_iterable()
395+
396+
# Fetch stats via RPC and populate the result
397+
try:
398+
stats = self.rpc_client.fetch_stats_wait_async(
399+
timeout=timeout).remote()
400+
except Exception as e:
401+
logger.debug(f"Error fetching stats via RPC: {e}")
402+
stats = []
403+
404+
for stat in stats:
405+
self._iter_stats_result.queue.put(stat)
406+
407+
self._iter_stats_result.set_timeout(timeout)
408+
return self._iter_stats_result
409+
410+
def get_kv_events(self, timeout: float) -> List[dict]:
411+
"""Get iteration KV events from the runtime via RPC.
412+
413+
Args:
414+
timeout (float): Max wait time in seconds for the RPC call.
415+
416+
Returns:
417+
List[dict]: A list of runtime events as dict.
418+
"""
419+
if self.rpc_client is None:
420+
logger.warning("RPC client not initialized, cannot get kv events")
421+
return []
422+
423+
try:
424+
events = self.rpc_client.fetch_kv_cache_events_wait_async(
425+
timeout=timeout).remote()
426+
return [json.loads(e) if isinstance(e, str) else e for e in events]
427+
except Exception as e:
428+
logger.error(f"Error fetching kv events via RPC: {e}")
429+
return []
430+
431+
def aget_kv_events(self, timeout: float) -> IterationResult:
432+
"""Get iteration KV events from the runtime via RPC (async).
433+
434+
Args:
435+
timeout (float): Max wait time in seconds for the RPC call.
436+
437+
Returns:
438+
IterationResult: An async iterable object containing runtime events.
439+
"""
440+
# Initialize iteration result if needed
441+
self._maybe_initialize_iteration_results()
442+
443+
if self._iter_kv_events_result is None:
444+
from .executor import empty_async_iterable
445+
return empty_async_iterable()
446+
447+
# Fetch kv events via RPC and populate the result
448+
try:
449+
events = self.rpc_client.fetch_kv_cache_events_wait_async(
450+
timeout=timeout).remote()
451+
except Exception as e:
452+
logger.debug(f"Error fetching kv events via RPC: {e}")
453+
events = []
454+
455+
for event in events:
456+
self._iter_kv_events_result.queue.put(event)
457+
458+
self._iter_kv_events_result.set_timeout(timeout)
459+
return self._iter_kv_events_result
460+
444461
def __del__(self):
445462
self.shutdown()
446463

0 commit comments

Comments
 (0)