@@ -334,6 +334,32 @@ def cancel_request(self, id: int):
334334 """
335335 self .executor_request_queue .enqueue_cancel_request (id )
336336
def enqueue_sleep_request(self, id: int, sleep_level: int):
    """
    Queue a sleep control request for the executor.

    Args:
        id (int): The request id for which to sleep
        sleep_level (int): The sleep level to apply to the request
    """
    # Delegate straight to the executor's request queue; the executor loop
    # picks the control request up on its next iteration.
    request_queue = self.executor_request_queue
    request_queue.enqueue_sleep_request(id, sleep_level)
def enqueue_wakeup_request(self, id: int, wakeup_level: int):
    """
    Enqueue a wakeup request with provided request id and wakeup level.

    Args:
        id (int): The request id for which to wakeup
        wakeup_level (int): The wakeup level to apply to the request
    """
    # Fix: the original docstring documented only `id` even though
    # `wakeup_level` is accepted and forwarded to the queue.
    self.executor_request_queue.enqueue_wakeup_request(id, wakeup_level)
def enqueue_update_weight_request(self, id: int, weight_ipc_handles: dict):
    """
    Enqueue an update-weight request with the provided request id and
    weight IPC handles.

    Args:
        id (int): The request id for which to update weights
        weight_ipc_handles (dict): The weight IPC handles to apply
            (presumably mapping weight names to IPC handles — schema is
            defined by the queue consumer; confirm against caller)
    """
    self.executor_request_queue.enqueue_update_weight_request(id, weight_ipc_handles)
337363 def shutdown (self ):
338364 """
339365 Signals the server to shutdown.
@@ -1080,19 +1106,19 @@ def update_weight_from_ipc_handles(self, handles):
10801106
def _sleep(self, sleep_request):
    # Acknowledge a sleep control request: clear the pending flag, then
    # send back an empty, final, successful response keyed by request id.
    self.is_sleep_request = False
    rid = sleep_request.id
    ack = LlmResponse(
        request_id=rid,
        result=LlmResult(
            result=None,
            py_result=PyResult(0, 0, success=True),
            is_final=True,
        ),
        client_id=rid,
    )
    self._enqueue_responses([(rid, ack)])
def _wakeup(self, wakeup_request):
    # Acknowledge a wakeup control request: clear the pending flag, then
    # send back an empty, final, successful response keyed by request id.
    self.is_wakeup_request = False
    rid = wakeup_request.id
    ack = LlmResponse(
        request_id=rid,
        result=LlmResult(
            result=None,
            py_result=PyResult(0, 0, success=True),
            is_final=True,
        ),
        client_id=rid,
    )
    self._enqueue_responses([(rid, ack)])
def _update_weight(self, update_weight_request):
    """
    Handle an update-weight control request.

    Applies the weights referenced by the request's IPC handles and, on
    success, acknowledges the request with an empty, final, successful
    response. On failure the error is printed and NO response is enqueued.

    Args:
        update_weight_request: Control request carrying `id` and
            `weight_ipc_handles`.
    """
    self.is_update_weight_request = False
    try:
        # Keep the try body minimal: only the weight update itself can
        # legitimately fail here.
        self.update_weight_from_ipc_handles(update_weight_request.weight_ipc_handles)
    except Exception as e:
        # NOTE(review): the failure path deliberately sends no response, so
        # the requester never hears back on error — confirm callers have a
        # timeout. TODO: enqueue a success=False LlmResponse once the
        # response path supports it (previously present but commented out).
        print(
            f"Error in update_weights_from_ipc_handles: {e}"
        )
    else:
        rid = update_weight_request.id
        update_weight_response = LlmResponse(
            request_id=rid,
            result=LlmResult(
                result=None,
                py_result=PyResult(0, 0, success=True),
                is_final=True,
            ),
            client_id=rid,
        )
        self._enqueue_responses([(rid, update_weight_response)])
1130+ def _handle_control_request (self ):
1131+ if len (self .executor_request_queue .control_requests ) > 0 :
1132+ assert len (self .executor_request_queue .control_requests ) == 1 , f"control request should be the only request in the list, but got { len (self .executor_request_queue .control_requests )} "
1133+ control_request = self .executor_request_queue .control_requests .pop ()
1134+ if (control_request .is_update_weight_request ):
1135+ self ._update_weight (control_request )
1136+ elif (control_request .is_sleep_request ):
1137+ self ._sleep (control_request )
1138+ elif (control_request .is_wakeup_request ):
1139+ self ._wakeup (control_request )
1140+ else :
1141+ assert False , "Invalid control request"
1142+
1143+
11041144 def _executor_loop_overlap (self ):
11051145 torch .cuda .set_device (self .device_id )
11061146 # ensure the context is created, otherwise, some MPI calls will fail.
@@ -1122,20 +1162,10 @@ def _executor_loop_overlap(self):
11221162 iter_start_time = time .time ()
11231163
11241164 scheduled_batch , iter_stats = self ._prepare_and_schedule_batch ()
1165+ self ._handle_control_request ()
1166+
11251167 if scheduled_batch is None :
11261168 break
1127- if self .is_control_request :
1128- self .is_control_request = False
1129- assert len (new_requests ) == 1 , f"control request should be the only request in the list, but got { len (new_requests )} "
1130- if (new_requests [0 ].is_update_weight_request ()):
1131- self ._update_weight (new_requests [0 ])
1132- elif (new_requests [0 ].is_sleep_request ()):
1133- self ._sleep (new_requests [0 ])
1134- elif (new_requests [0 ].is_wakeup_request ()):
1135- self ._wakeup (new_requests [0 ])
1136- else :
1137- assert False , "Invalid control request"
1138- continue
11391169
11401170 self ._pause_requests (scheduled_batch .paused_requests )
11411171
0 commit comments