1515
1616from tensorrt_llm ._torch .pyexecutor .resource_manager import ResourceManagerType
1717from tensorrt_llm ._torch .pyexecutor .seq_slot_manager import SeqSlotManager
18+ from tensorrt_llm ._torch .utils import get_device_uuid
1819from tensorrt_llm ._utils import (customized_gc_thresholds , global_mpi_rank ,
1920 is_trace_enabled , nvtx_range , trace_func )
2021from tensorrt_llm .bindings .executor import (DisServingRequestStats ,
3536from .guided_decoder import GuidedDecoder
3637from .kv_cache_transceiver import KvCacheTransceiver
3738from .llm_request import (ExecutorRequest , LlmRequest , LlmRequestState ,
38- LlmResponse )
39+ LlmResponse , LlmResult , executor_request_to_llm_request , PyResult )
3940from .model_engine import ModelEngine
4041from .sampler import Sampler , SampleState , SampleStateTensors
4142from .scheduler import RequestScheduler , ScheduledRequests
@@ -184,6 +185,7 @@ def __init__(self,
184185 self .num_fetch_requests_cur_rank = 0
185186 self .num_fetch_requests = 0
186187 self .shutdown_event = threading .Event ()
188+ self .request_accumulator : List [RequestQueueItem ] = []
187189
188190 # response used data
189191 self .response_lock = threading .Lock ()
@@ -235,6 +237,8 @@ def __init__(self,
235237 )
236238 self .executor_request_queue .set_exclude_last_generation_logits (
237239 self .disable_overlap_scheduler , self .sampler )
240+ self .is_control_request = False
241+ self .control_request_id = 0
238242
239243 self .stats_lock = threading .Lock ()
240244 self .stats = []
@@ -383,12 +387,29 @@ def wait_shutdown(self):
383387
384388 def enqueue_request (self ,
385389 request : ExecutorRequest ,
386- query : Optional [List ] = None ) -> int :
390+ query : Optional [List ] = None ,
391+ weight_ipc_handles : Optional [dict ] = None ,
392+ sleep_level : Optional [int ] = None ,
393+ wakeup_level : Optional [int ] = None ) -> int :
387394 """
388395 Enqueue a new request, query is only used in `StarAttention`.
389396 """
390397 req_id = self .executor_request_queue .enqueue_request (request , query )
391398
399+ ## if weight_ipc_handles is not None:
400+ ## self.request_queue.put(RequestQueueItem(UPDATE_WEIGHT_REQUEST_ID, None, False, None, weight_ipc_handles))
401+ ## elif sleep_level is not None:
402+ ## self.request_queue.put(RequestQueueItem(SLEEP_REQUEST_ID, None, False, None, None, sleep_level))
403+ ## elif wakeup_level is not None:
404+ ## self.request_queue.put(RequestQueueItem(WAKEUP_REQUEST_ID, None, False, None, None, None, wakeup_level))
405+ ## elif query is not None:
406+ ## self.request_queue.put(RequestQueueItem(req_id, request, query))
407+ ## else:
408+ ## self.request_queue.put(RequestQueueItem(req_id, request))
409+ ## #self.request_queue.put(RequestQueueItem(req_id, request, False, query, weight_ipc_handles, sleep_level, wakeup_level))
410+ ## self.next_req_id += 1
411+ ## finally:
412+ ## self.enqueue_lock.release()
392413 return req_id
393414
394415 def set_gather_responses (self , gather_all_responses ):
@@ -666,6 +687,18 @@ def _executor_loop_pp(self):
666687 new_requests = self ._fetch_and_activate_new_requests ()
667688 if self .should_stop_processing :
668689 break
690+ if self .is_control_request :
691+ self .is_control_request = False
692+ assert len (new_requests ) == 1 , f"control request should be the only request in the list, but got { len (new_requests )} "
693+ if (new_requests [0 ].is_update_weight_request ()):
694+ self ._update_weight (new_requests [0 ])
695+ elif (new_requests [0 ].is_sleep_request ()):
696+ self ._sleep (new_requests [0 ])
697+ elif (new_requests [0 ].is_wakeup_request ()):
698+ self ._wakeup (new_requests [0 ])
699+ else :
700+ assert False , "Invalid control request"
701+ continue
669702
670703 if self .kv_cache_transceiver :
671704 self ._check_disagg_gen_transfer_status ()
@@ -914,6 +947,18 @@ def _executor_loop(self):
914947 scheduled_batch , iter_stats = self ._prepare_and_schedule_batch ()
915948 if scheduled_batch is None :
916949 break
950+ if self .is_control_request :
951+ self .is_control_request = False
952+ assert len (new_requests ) == 1 , f"control request should be the only request in the list, but got { len (new_requests )} "
953+ if (new_requests [0 ].is_update_weight_request ()):
954+ self ._update_weight (new_requests [0 ])
955+ elif (new_requests [0 ].is_sleep_request ()):
956+ self ._sleep (new_requests [0 ])
957+ elif (new_requests [0 ].is_wakeup_request ()):
958+ self ._wakeup (new_requests [0 ])
959+ else :
960+ assert False , "Invalid control request"
961+ continue
917962
918963 self ._pause_requests (scheduled_batch .paused_requests )
919964
@@ -995,6 +1040,67 @@ def _prepare_draft_requests(self):
9951040 logger .error (f"Encountered an error in decode: { error_msg } " )
9961041 self ._handle_errors (error_msg )
9971042
    def update_weights(self, weights):
        """Load `weights` into the model and wait for the copy to finish.

        Args:
            weights: parameter-name -> tensor mapping accepted by
                `model.load_weights` — presumably a state-dict-like mapping;
                verify against the model engine's `load_weights` contract.
        """
        # Load weights into the model
        self.model_engine.model.load_weights(weights)
        # Block until all in-flight device work (the weight copies) completes,
        # so callers can safely report the update as done.
        torch.cuda.synchronize()

        # TODO: reset prefix cache
1049+
1050+ def update_weight_from_ipc_handles (self , handles ):
1051+ """
1052+ Update model weights from IPC handles.
1053+
1054+ Args:
1055+ ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles.
1056+ {device_uuid: all_handles}
1057+ """
1058+ from tensorrt_llm ._torch .utils import get_device_uuid
1059+ device_uuid = get_device_uuid (self .device_id )
1060+
1061+ if device_uuid not in handles :
1062+ raise ValueError (f"Device UUID { device_uuid } not found in ipc_handles" )
1063+
1064+ try :
1065+ weights = {}
1066+ all_handles = handles [device_uuid ]
1067+
1068+ for param_name , tensor_handle in all_handles :
1069+ func , args = tensor_handle
1070+ list_args = list (args )
1071+ list_args [6 ] = self .device_id # Set target device
1072+ tensor = func (* list_args )
1073+ weights [param_name ] = tensor
1074+
1075+ self .update_weights (weights )
1076+
1077+ except Exception as e :
1078+ logger .error (f"failed to update weights from ipc handles: { e } " )
1079+ raise e
1080+
1081+ def _sleep (self , sleep_request ):
1082+ self .is_sleep_request = False
1083+ self ._enqueue_responses ({sleep_request .id : LlmResponse (request_id = sleep_request .id , result = LlmResult (result = None , py_result = PyResult (0 , 0 , success = True ), is_final = True ), client_id = sleep_request .id )})
1084+
1085+ def _wakeup (self , wakeup_request ):
1086+ self .is_wakeup_request = False
1087+ self ._enqueue_responses ({wakeup_request .id : LlmResponse (request_id = wakeup_request .id , result = LlmResult (result = None , py_result = PyResult (0 , 0 , success = True ), is_final = True ), client_id = wakeup_request .id )})
1088+
1089+ def _update_weight (self , update_weight_request ):
1090+ self .is_update_weight_request = False
1091+
1092+ try :
1093+ self .update_weight_from_ipc_handles (update_weight_request .weight_ipc_handles )
1094+ update_weight_response = LlmResponse (request_id = update_weight_request .id , result = LlmResult (result = None , py_result = PyResult (0 , 0 , success = True ), is_final = True ), client_id = update_weight_request .id )
1095+ self ._enqueue_responses ({update_weight_request .id : update_weight_response })
1096+ except Exception as e :
1097+ print (
1098+ f"Error in update_weights_from_ipc_handles: { e } "
1099+ )
1100+ raise e
1101+ #update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=False), is_final=True), client_id=update_weight_request.id)
1102+ #self._enqueue_responses({update_weight_request.id: update_weight_response})
1103+
9981104 def _executor_loop_overlap (self ):
9991105 torch .cuda .set_device (self .device_id )
10001106 # ensure the context is created, otherwise, some MPI calls will fail.
@@ -1018,6 +1124,18 @@ def _executor_loop_overlap(self):
10181124 scheduled_batch , iter_stats = self ._prepare_and_schedule_batch ()
10191125 if scheduled_batch is None :
10201126 break
1127+ if self .is_control_request :
1128+ self .is_control_request = False
1129+ assert len (new_requests ) == 1 , f"control request should be the only request in the list, but got { len (new_requests )} "
1130+ if (new_requests [0 ].is_update_weight_request ()):
1131+ self ._update_weight (new_requests [0 ])
1132+ elif (new_requests [0 ].is_sleep_request ()):
1133+ self ._sleep (new_requests [0 ])
1134+ elif (new_requests [0 ].is_wakeup_request ()):
1135+ self ._wakeup (new_requests [0 ])
1136+ else :
1137+ assert False , "Invalid control request"
1138+ continue
10211139
10221140 self ._pause_requests (scheduled_batch .paused_requests )
10231141