@@ -17,6 +17,7 @@

 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
+from tensorrt_llm._torch.utils import get_device_uuid
 from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
                                  is_trace_enabled, nvtx_range, trace_func)
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
@@ -33,7 +34,7 @@
 from ..speculative.drafter import Drafter
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
-                          LlmResponse, executor_request_to_llm_request)
+                          LlmResponse, LlmResult, executor_request_to_llm_request)
 from .model_engine import ModelEngine
 from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler
 from .scheduler import RequestScheduler, ScheduledRequests
@@ -51,6 +52,9 @@
 PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"

 SHUTDOWN_REQUEST_ID = -1
+UPDATE_WEIGHT_REQUEST_ID = -2
+SLEEP_REQUEST_ID = -3
+WAKEUP_REQUEST_ID = -4


 @dataclasses.dataclass
@@ -59,15 +63,25 @@ class RequestQueueItem:
     request: Optional[ExecutorRequest] = None
     is_canceled_request: bool = False
     query: Optional[list] = None  # only used in `StarAttention`
+    weight_ipc_handles: Optional[dict] = None
+    sleep_level: Optional[int] = None
+    wakeup_level: Optional[int] = None

     @property
     def is_shutdown_request(self):
         return self.id == SHUTDOWN_REQUEST_ID

     @property
     def is_normal_request(self):
-        return not (self.is_shutdown_request or self.is_canceled_request)
+        # Control requests use negative sentinel ids, so any positive id
+        # that is not canceled is a normal request.
+        return self.id > 0 and not self.is_canceled_request
+
+    def is_update_weight_request(self):
+        return self.id == UPDATE_WEIGHT_REQUEST_ID

+    def is_sleep_request(self):
+        return self.id == SLEEP_REQUEST_ID
+
+    def is_wakeup_request(self):
+        return self.id == WAKEUP_REQUEST_ID

 def _get_from_request_queue(
         request_queue,
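
Taken together, the negative sentinel ids turn the request queue into a small control plane: `RequestQueueItem.id > 0` means payload, `id < 0` means a control command. A minimal sketch of the intended dispatch semantics (constructor arguments follow the dataclass fields above; the id and payload values are illustrative only):

```python
# Illustrative only: sentinel ids route queue items to control handlers.
item = RequestQueueItem(UPDATE_WEIGHT_REQUEST_ID,
                        weight_ipc_handles={"gpu-uuid": []})
assert item.is_update_weight_request()
assert not item.is_normal_request  # negative id => not a normal request

normal = RequestQueueItem(7)  # request payload omitted for brevity
assert normal.is_normal_request and not normal.is_shutdown_request
```
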
@@ -244,6 +258,7 @@ def __init__(self,
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
+        self.request_accumulator: List[RequestQueueItem] = []

         # response used data
         self.response_lock = threading.Lock()
@@ -287,6 +302,8 @@ def __init__(self,
             self.draft_model_engine.warmup(self.resource_manager)

         self.is_shutdown = False
+        self.is_control_request = False
+        self.control_request_id = 0

         self.stats_lock = threading.Lock()
         self.stats = []
@@ -465,7 +482,10 @@ def wait_shutdown(self):

     def enqueue_request(self,
                         request: ExecutorRequest,
-                        query: Optional[List] = None):
+                        query: Optional[List] = None,
+                        weight_ipc_handles: Optional[dict] = None,
+                        sleep_level: Optional[int] = None,
+                        wakeup_level: Optional[int] = None):
         """
         Enqueue a new request, query is only used in `StarAttention`.
         """
@@ -476,10 +496,17 @@ def enqueue_request(self,
             if self.enable_iter_perf_stats:
                 self.start_times[req_id] = time.time()

-            if query is not None:
+            # Control requests carry a negative sentinel id and a keyword
+            # payload instead of an ExecutorRequest.
+            if weight_ipc_handles is not None:
+                self.request_queue.put(RequestQueueItem(
+                    UPDATE_WEIGHT_REQUEST_ID,
+                    weight_ipc_handles=weight_ipc_handles))
+            elif sleep_level is not None:
+                self.request_queue.put(RequestQueueItem(
+                    SLEEP_REQUEST_ID, sleep_level=sleep_level))
+            elif wakeup_level is not None:
+                self.request_queue.put(RequestQueueItem(
+                    WAKEUP_REQUEST_ID, wakeup_level=wakeup_level))
+            elif query is not None:
                 self.request_queue.put(RequestQueueItem(req_id, request, query))
             else:
                 self.request_queue.put(RequestQueueItem(req_id, request))
             self.next_req_id += 1
         finally:
             self.enqueue_lock.release()
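
With these parameters, one `enqueue_request` call covers both the data and the control paths. A hedged usage sketch; it assumes the caller passes `None` for the unused `ExecutorRequest` argument and that `handles_by_uuid` maps device UUIDs to IPC handles in the shape `_update_weight` consumes below:

```python
# Hedged usage sketch of the three control paths (argument values are
# assumptions for illustration, not part of this PR's tests).
executor.enqueue_request(None, weight_ipc_handles=handles_by_uuid)
executor.enqueue_request(None, sleep_level=1)   # pause the executor
executor.enqueue_request(None, wakeup_level=1)  # resume it
```
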
@@ -756,6 +783,18 @@ def _executor_loop_pp(self):
                 new_requests = self._fetch_new_requests()
                 if self.should_stop_processing:
                     break
+                if self.is_control_request:
+                    self.is_control_request = False
+                    assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}"
+                    if new_requests[0].is_update_weight_request():
+                        self._update_weight(new_requests[0])
+                    elif new_requests[0].is_sleep_request():
+                        self._sleep(new_requests[0])
+                    elif new_requests[0].is_wakeup_request():
+                        self._wakeup(new_requests[0])
+                    else:
+                        raise ValueError("Invalid control request")
+                    continue

                 if self.enable_iter_perf_stats:
                     iter_stats = self._get_init_iter_stats(
@@ -907,6 +946,18 @@ def _executor_loop(self):
                 new_requests = self._fetch_new_requests()
                 if self.should_stop_processing:
                     break
+                if self.is_control_request:
+                    self.is_control_request = False
+                    assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}"
+                    if new_requests[0].is_update_weight_request():
+                        self._update_weight(new_requests[0])
+                    elif new_requests[0].is_sleep_request():
+                        self._sleep(new_requests[0])
+                    elif new_requests[0].is_wakeup_request():
+                        self._wakeup(new_requests[0])
+                    else:
+                        raise ValueError("Invalid control request")
+                    continue

                 if self.kv_cache_transceiver:
                     self._check_disagg_gen_transfer_status()
@@ -1033,6 +1084,50 @@ def _prepare_draft_requests(self):
             logger.error(f"Encountered an error in decode: {error_msg}")
             self._handle_errors(error_msg)

+    def _sleep(self, sleep_request):
+        # No real sleep logic yet: acknowledge with an empty final response.
+        response = LlmResponse(request_id=sleep_request.id,
+                               result=LlmResult(result=None, py_result=None,
+                                                is_final=True),
+                               client_id=sleep_request.id)
+        self._enqueue_responses({sleep_request.id: response})
+
+    def _wakeup(self, wakeup_request):
+        # No real wakeup logic yet: acknowledge with an empty final response.
+        response = LlmResponse(request_id=wakeup_request.id,
+                               result=LlmResult(result=None, py_result=None,
+                                                is_final=True),
+                               client_id=wakeup_request.id)
+        self._enqueue_responses({wakeup_request.id: response})
+
+    def _update_weight(self, update_weight_request):
+        try:
+            # Look up the IPC handles addressed to this device.
+            device_uuid = get_device_uuid(self.device_id)
+            handles = update_weight_request.weight_ipc_handles[device_uuid]
+
+            # Rebuild each tensor from its (rebuild_func, args) handle,
+            # remapping the embedded device index to the current device.
+            weights = {}
+            for name, handle in handles:
+                func, args = handle
+                list_args = list(args)
+                list_args[6] = self.device_id
+                weights[name] = func(*list_args)
+
+            # Load the rebuilt weights into the model.
+            self.model_engine.model.load_weights(weights)
+            torch.cuda.synchronize()
+        except Exception as e:
+            logger.error(f"Error in PyExecutor._update_weight: {e}")
+        finally:
+            # Acknowledge the request whether or not the update succeeded.
+            response = LlmResponse(request_id=update_weight_request.id,
+                                   result=LlmResult(result=None,
+                                                    py_result=None,
+                                                    is_final=True),
+                                   client_id=update_weight_request.id)
+            self._enqueue_responses({update_weight_request.id: response})
+
     def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
         if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver:
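
`_update_weight` assumes each handle is a `(rebuild_func, args)` pair such as the one produced by `torch.multiprocessing.reductions.reduce_tensor`, whose CUDA rebuild arguments carry the device index at position 6; that is what `list_args[6]` remaps. A hedged sender-side sketch (`build_weight_ipc_handles` is a hypothetical helper, not part of this PR):

```python
import torch
from torch.multiprocessing.reductions import reduce_tensor

from tensorrt_llm._torch.utils import get_device_uuid


def build_weight_ipc_handles(named_weights, device_id=0):
    """Bundle CUDA tensors into the per-device dict _update_weight expects."""
    handles = [(name, reduce_tensor(tensor.detach()))
               for name, tensor in named_weights]
    # Keyed by device UUID so each executor rank picks out its own bundle.
    return {get_device_uuid(device_id): handles}
```
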
@@ -1052,6 +1147,18 @@ def _executor_loop_overlap(self):
                 new_requests = self._fetch_new_requests()
                 if self.should_stop_processing:
                     break
+                if self.is_control_request:
+                    self.is_control_request = False
+                    assert len(new_requests) == 1, f"control request should be the only request in the list, but got {len(new_requests)}"
+                    if new_requests[0].is_update_weight_request():
+                        self._update_weight(new_requests[0])
+                    elif new_requests[0].is_sleep_request():
+                        self._sleep(new_requests[0])
+                    elif new_requests[0].is_wakeup_request():
+                        self._wakeup(new_requests[0])
+                    else:
+                        raise ValueError("Invalid control request")
+                    continue

                 if self.kv_cache_transceiver:
                     self._check_disagg_gen_transfer_status()
@@ -1263,20 +1370,43 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
         new_requests, py_request_objects = self._broadcast_new_requests(
             new_requests, py_request_objects)

+        self.request_accumulator.extend(new_requests)
+
         # drop requests arriving after shutdown
         valid_new_requests = []
-        for req_item in new_requests:
+        # The first control request (shutdown, update-weight, sleep, or
+        # wakeup) acts as a barrier: requests ahead of it are processed
+        # now; the control request itself is dispatched alone once it
+        # reaches the front of the accumulator.
+        control_request_index = None
+        for i, req_item in enumerate(self.request_accumulator):
             if req_item.is_shutdown_request:
                 self.is_shutdown = True
+                control_request_index = i
                 break
+            elif (req_item.is_update_weight_request()
+                  or req_item.is_sleep_request()
+                  or req_item.is_wakeup_request()):
+                control_request_index = i
+                self.control_request_id = req_item.id
+                break
             elif req_item.is_canceled_request:
                 self.canceled_req_ids.append(req_item.id)
             else:
                 valid_new_requests.append(req_item)
+
+        if control_request_index is not None:
+            if control_request_index == 0:
+                if not self.is_shutdown:
+                    # Return the control request alone; the executor loops
+                    # dispatch it before fetching again.
+                    valid_new_requests = self.request_accumulator[:1]
+                    self.is_control_request = True
+                self.request_accumulator = self.request_accumulator[1:]
+                return valid_new_requests
+            # Defer the control request and everything behind it to a
+            # later fetch.
+            self.request_accumulator = self.request_accumulator[
+                control_request_index:]
+        else:
+            self.request_accumulator = []
+
         # Check if the beam width of the requests is equal to the max_beam_width
         for req_item in valid_new_requests:
             assert req_item.request.sampling_config.beam_width == self.max_beam_width, f"Request beam width {req_item.request.sampling_config.beam_width} is not equal to max_beam_width {self.max_beam_width}. This is not supported!"

+        new_requests = valid_new_requests
+
         if py_request_objects and (self.dist.tp_size > 1
                                    or self.dist.has_pp) and self.dist.rank > 0:
             for attr_name, req_obj_dict in py_request_objects:
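
The accumulator thus makes control requests behave as in-order barriers. A hedged, self-contained model of the split semantics (plain ints stand in for `RequestQueueItem`s; negative values play the role of control ids):

```python
# Mirrors the barrier logic above: items before the first control id are
# returned now; the control id is handed back alone on the next fetch.
def split_at_barrier(items):
    for i, item in enumerate(items):
        if item < 0:  # control request sentinel
            return (items[:1] if i == 0 else items[:i]), items[max(i, 1):]
    return items, []

assert split_at_barrier([1, -3, 2]) == ([1], [-3, 2])  # defer the barrier
assert split_at_barrier([-3, 2]) == ([-3], [2])        # dispatch it alone
assert split_at_barrier([1, 2]) == ([1, 2], [])        # no barrier
```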