11import atexit
22import concurrent .futures
3+ import json
4+ import os
35import threading
4- import time
56import weakref
6- from typing import Dict , Optional , Union
7+ from typing import Dict , List , Optional
78
89import torch
910import zmq
2223from .postproc_worker import PostprocWorker , PostprocWorkerConfig
2324from .request import CancellingRequest , GenerationRequest
2425from .result import GenerationResult , IterationResult
25- from .utils import (ErrorResponse , IntraProcessQueue , WorkerCommIpcAddrs ,
26- create_mpi_comm_session , get_spawn_proxy_process_env ,
27- is_llm_response , print_alive_threads )
26+ from .rpc import RPCClient
27+ from .rpc .rpc_common import get_unique_ipc_addr
28+ from .utils import (ErrorResponse , WorkerCommIpcAddrs , create_mpi_comm_session ,
29+ get_spawn_proxy_process_env , is_llm_response ,
30+ print_alive_threads )
2831from .worker import GenerationExecutorWorker , worker_main
2932
3033__all__ = [
@@ -89,19 +92,27 @@ def __init__(
8992 "llm_args" ].garbage_collection_gen0_threshold if worker_kwargs .get (
9093 "llm_args" , None ) is not None else None
9194
95+ # Generate RPC address and key for stats RPC
96+ self .rpc_addr = get_unique_ipc_addr ()
97+ self .hmac_key = os .urandom (32 )
98+
9299 worker_kwargs = dict (** worker_kwargs ,
93100 worker_queues = self ._setup_queues (),
94101 postproc_worker_config = postproc_worker_config ,
95- is_llm_executor = False )
102+ is_llm_executor = False ,
103+ rpc_addr = self .rpc_addr ,
104+ hmac_key = self .hmac_key )
96105
97106 if "log_level" not in worker_kwargs :
98107 worker_kwargs ["log_level" ] = logger .level
99108
100109 self .dispatch_result_thread : Optional [ManagedThread ] = None
101- self .dispatch_stats_thread : Optional [ManagedThread ] = None
102- self .dispatch_kv_cache_events_thread : Optional [ManagedThread ] = None
110+ self .rpc_client : Optional [RPCClient ] = None
103111 self ._start_executor_workers (worker_kwargs )
104112
113+ # Create RPC client after workers are started (worker starts RPC server)
114+ self .rpc_client = RPCClient (self .rpc_addr , hmac_key = self .hmac_key )
115+
105116 # MPI registers its joiner using threading._register_atexit if possible.
106117 # These functions run before atexit.register, so to avoid deadlock,
107118 # we have to notify workers to exit before MPI starts to wait them.
@@ -128,19 +139,11 @@ def _setup_queues(self) -> WorkerCommIpcAddrs:
128139 socket_type = zmq .PULL
129140 if self .enable_postprocess_parallel else zmq .PAIR ,
130141 name = "proxy_result_queue" )
131- self .mp_stats_queue = FusedIpcQueue (is_server = True ,
132- fuse_message = False ,
133- name = "proxy_stats_queue" )
134- self .kv_cache_events_queue = FusedIpcQueue (
135- is_server = True ,
136- fuse_message = False ,
137- name = "proxy_kv_cache_events_queue" )
142+ # Stats and KV events are now fetched via RPC, not IPC queues.
138143 return WorkerCommIpcAddrs (
139144 request_queue_addr = self .request_queue .address ,
140145 worker_init_status_queue_addr = self .worker_init_status_queue .address ,
141146 result_queue_addr = self .result_queue .address ,
142- stats_queue_addr = self .mp_stats_queue .address ,
143- kv_cache_events_queue_addr = self .kv_cache_events_queue .address ,
144147 )
145148
146149 def abort_request (self , request_id : int ) -> None :
@@ -204,71 +207,8 @@ def process_res(res):
204207
205208 return True # success
206209
207- def _iteration_result_task (self ,
208- queue : Union [FusedIpcQueue , IntraProcessQueue ],
209- result_singleton : IterationResult ,
210- urgent : bool = False ) -> bool :
211- if not urgent :
212- time .sleep (0.2 )
213-
214- try :
215- data = queue .get ()
216- except :
217- logger .debug (
218- "proxy.py: Error in _iteration_result_task: queue.get()" )
219- return False
220-
221- if data is None :
222- logger .debug ("proxy.py: _iteration_result_task: data is None" )
223- return False # shutdown the thread
224-
225- data = data if isinstance (data , list ) else [data ]
226- queue = result_singleton .queue
227- async_queues = []
228-
229- while queue .full ():
230- queue .get ()
231-
232- try :
233- for d in data :
234- if d is None :
235- logger .debug ("proxy.py: _iteration_result_task: d is None" )
236- return False
237-
238- if isinstance (queue , _SyncQueue ):
239- queue .put_nowait (d )
240- async_queues .append (queue )
241- else :
242- queue .put (d )
243-
244- if async_queues :
245- _SyncQueue .notify_many (queue .loop , async_queues )
246-
247- except AsyncQueue .EventLoopShutdownError :
248- # This happens in the last loop while the generate workflow is
249- # stopped, or when get_stats() or aget_stats() are not called by users
250- # and therefore event loop can already be closed.
251- logger .debug ("proxy.py: EventLoopShutdownError" )
252- except Exception as e :
253- logger .debug (f"proxy.py: Error in _iteration_result_task: { e } " )
254- raise e
255-
256- return True # success
257-
258- def dispatch_stats_task (self ) -> bool :
259- if not self ._iter_stats_result :
260- # This can happen temporarily because the WAR in tensorrt_llm/bench/benchmark/throughput.py
261- # is not synchronized with self.dispatch_stats_thread.
262- logger .debug (
263- f"Skipping stats dispatch while self._iter_stats_result=None" )
264- return True # Intended behavior, not an error
265- return self ._iteration_result_task (self .mp_stats_queue ,
266- self ._iter_stats_result )
267-
268- def dispatch_kv_cache_events_task (self ) -> bool :
269- return self ._iteration_result_task (self .kv_cache_events_queue ,
270- self ._iter_kv_events_result ,
271- urgent = True )
210+ # NOTE: _iteration_result_task, dispatch_stats_task, and dispatch_kv_cache_events_task
211+ # have been removed as stats and kv_events are now fetched via RPC directly.
272212
273213 def _start_dispatch_threads (self ):
274214 if self .dispatch_result_thread is None :
@@ -277,25 +217,9 @@ def _start_dispatch_threads(self):
277217 weakref .WeakMethod (self .dispatch_result_task ),
278218 error_queue = self ._error_queue ,
279219 name = "proxy_dispatch_result_thread" )
280- self .dispatch_stats_thread = ManagedThread (
281- weakref .WeakMethod (self .dispatch_stats_task ),
282- error_queue = self ._error_queue ,
283- name = "proxy_dispatch_stats_thread" )
284- self .dispatch_kv_cache_events_thread = ManagedThread (
285- weakref .WeakMethod (self .dispatch_kv_cache_events_task ),
286- error_queue = self ._error_queue ,
287- name = "proxy_dispatch_kv_cache_events_thread" )
288220
289221 self .dispatch_result_thread .start ()
290222
291- # Only collect stats when submission
292- # is via LLM API
293- if self ._iter_stats_result :
294- self .dispatch_stats_thread .start ()
295-
296- if self ._iter_kv_events_result :
297- self .dispatch_kv_cache_events_thread .start ()
298-
299223 self ._handle_background_error ()
300224
301225 def _start_executor_workers (self , worker_kwargs ):
@@ -387,23 +311,18 @@ def shutdown(self):
387311 ):
388312 self .dispatch_result_thread .stop ()
389313 self .dispatch_result_thread .join ()
390- if self .dispatch_stats_thread is not None and self .dispatch_stats_thread .is_alive (
391- ):
392- self .dispatch_stats_thread .stop ()
393- self .dispatch_stats_thread .join ()
394- if self .dispatch_kv_cache_events_thread is not None and self .dispatch_kv_cache_events_thread .is_alive (
395- ):
396- self .dispatch_kv_cache_events_thread .stop ()
397- self .dispatch_kv_cache_events_thread .join ()
398314
399315 # step3: finish all remaining work
400316
317+ # close the RPC client
318+ if self .rpc_client is not None :
319+ self .rpc_client .close ()
320+ self .rpc_client = None
321+
401322 # close all the sockets
402323 self .request_queue .close ()
403324 self .worker_init_status_queue .close ()
404325 self .result_queue .close ()
405- self .mp_stats_queue .close ()
406- self .kv_cache_events_queue .close ()
407326
408327 self .workers_started = False
409328 self .mpi_session .shutdown ()
@@ -441,6 +360,104 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
441360
442361 return result
443362
def get_stats(self, timeout: float) -> List[dict]:
    """Fetch iteration statistics from the runtime over RPC.

    Args:
        timeout (float): Max wait time in seconds, forwarded to the RPC call.

    Returns:
        List[dict]: The fetched stats entries. Entries that arrive as
            strings are JSON-decoded; any other entry is passed through
            unchanged. An empty list is returned when no RPC client exists.
    """
    client = self.rpc_client
    if client is None:
        logger.warning("RPC client not initialized, cannot get stats")
        return []

    raw_stats = client.fetch_stats_wait_async(timeout=timeout).remote()

    # Normalize every entry: only string payloads need JSON decoding.
    decoded = []
    for entry in raw_stats:
        if isinstance(entry, str):
            entry = json.loads(entry)
        decoded.append(entry)
    return decoded
378+
def aget_stats(self, timeout: float) -> IterationResult:
    """Get iteration statistics from the runtime via RPC (async).

    The stats are fetched within this call and pushed onto the queue of the
    shared iteration result before that result is returned to the caller.

    Args:
        timeout (float): Max wait time in seconds for the RPC call. The same
            value is also applied to the returned result via ``set_timeout``.

    Returns:
        IterationResult: An async iterable object containing runtime stats,
            or an empty async iterable when iteration statistics are not
            available.
    """
    # Initialize iteration result if needed
    self._maybe_initialize_iteration_results()

    if self._iter_stats_result is None:
        logger.warning("Iteration statistics are not available yet.")
        from .executor import empty_async_iterable
        return empty_async_iterable()

    stats = []
    if self.rpc_client is None:
        # The client is unset before the workers start and after shutdown.
        # Report that explicitly instead of letting the AttributeError be
        # swallowed below and logged as if it were an RPC failure.
        logger.warning("RPC client not initialized, cannot get stats")
    else:
        # Fetch stats via RPC; a failed fetch is best-effort and simply
        # yields no new entries.
        try:
            stats = self.rpc_client.fetch_stats_wait_async(
                timeout=timeout).remote()
        except Exception as e:
            logger.debug(f"Error fetching stats via RPC: {e}")
            stats = []

    for stat in stats:
        self._iter_stats_result.queue.put(stat)

    self._iter_stats_result.set_timeout(timeout)
    return self._iter_stats_result
409+
def get_kv_events(self, timeout: float) -> List[dict]:
    """Fetch iteration KV cache events from the runtime over RPC.

    Args:
        timeout (float): Max wait time in seconds, forwarded to the RPC call.

    Returns:
        List[dict]: The fetched events. Entries that arrive as strings are
            JSON-decoded; any other entry is passed through unchanged. An
            empty list is returned when no RPC client exists or when the
            fetch/decoding raises.
    """
    if self.rpc_client is None:
        logger.warning("RPC client not initialized, cannot get kv events")
        return []

    # Both the RPC call and the decoding stay inside the try: any failure in
    # either step is logged and reported to the caller as "no events".
    try:
        pending = self.rpc_client.fetch_kv_cache_events_wait_async(
            timeout=timeout)
        decoded = []
        for event in pending.remote():
            if isinstance(event, str):
                event = json.loads(event)
            decoded.append(event)
    except Exception as e:
        logger.error(f"Error fetching kv events via RPC: {e}")
        return []
    return decoded
430+
def aget_kv_events(self, timeout: float) -> IterationResult:
    """Get iteration KV events from the runtime via RPC (async).

    The events are fetched within this call and pushed onto the queue of the
    shared KV-events result before that result is returned to the caller.

    Args:
        timeout (float): Max wait time in seconds for the RPC call. The same
            value is also applied to the returned result via ``set_timeout``.

    Returns:
        IterationResult: An async iterable object containing runtime events,
            or an empty async iterable when KV events are not available.
    """
    # Initialize iteration result if needed
    self._maybe_initialize_iteration_results()

    if self._iter_kv_events_result is None:
        from .executor import empty_async_iterable
        return empty_async_iterable()

    events = []
    if self.rpc_client is None:
        # The client is unset before the workers start and after shutdown.
        # Report that explicitly instead of letting the AttributeError be
        # swallowed below and logged as if it were an RPC failure.
        logger.warning("RPC client not initialized, cannot get kv events")
    else:
        # Fetch kv events via RPC; a failed fetch is best-effort and simply
        # yields no new entries.
        try:
            events = self.rpc_client.fetch_kv_cache_events_wait_async(
                timeout=timeout).remote()
        except Exception as e:
            logger.debug(f"Error fetching kv events via RPC: {e}")
            events = []

    for event in events:
        self._iter_kv_events_result.queue.put(event)

    self._iter_kv_events_result.set_timeout(timeout)
    return self._iter_kv_events_result
460+
def __del__(self):
    """Run ``shutdown()`` when the instance is finalized."""
    self.shutdown()
446463
0 commit comments