@@ -1084,6 +1084,44 @@ def _prepare_draft_requests(self):
10841084 logger .error (f"Encountered an error in decode: { error_msg } " )
10851085 self ._handle_errors (error_msg )
10861086
1087+ def update_weights (self , weights ):
1088+ # Load weights into the model
1089+ self .model_engine .model .load_weights (weights )
1090+ torch .cuda .synchronize ()
1091+
1092+ # TODO: reset prefix cache
1093+
1094+ def update_weight_from_ipc_handles (self , handles ):
1095+ """
1096+ Update model weights from IPC handles.
1097+
1098+ Args:
1099+ ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles.
1100+ {device_uuid: all_handles}
1101+ """
1102+ from tensorrt_llm ._torch .utils import get_device_uuid
1103+ device_uuid = get_device_uuid (self .device_id )
1104+
1105+ if device_uuid not in handles :
1106+ raise ValueError (f"Device UUID { device_uuid } not found in ipc_handles" )
1107+
1108+ try :
1109+ weights = {}
1110+ all_handles = handles [device_uuid ]
1111+
1112+ for param_name , tensor_handle in all_handles :
1113+ func , args = tensor_handle
1114+ list_args = list (args )
1115+ list_args [6 ] = self .device_id # Set target device
1116+ tensor = func (* list_args )
1117+ weights [param_name ] = tensor
1118+
1119+ self .update_weights (weights )
1120+
1121+ except Exception as e :
1122+ logger .error (f"failed to update weights from ipc handles: { e } " )
1123+ return False
1124+
10871125 def _sleep (self , sleep_request ):
10881126 self .is_sleep_request = False
10891127 self ._enqueue_responses ({sleep_request .id : LlmResponse (request_id = sleep_request .id , result = LlmResult (result = None , py_result = PyResult (0 , 0 , success = True ), is_final = True ), client_id = sleep_request .id )})
@@ -1096,24 +1134,7 @@ def _update_weight(self, update_weight_request):
10961134 self .is_update_weight_request = False
10971135
10981136 try :
1099- # Get handles for this device
1100- device_uuid = get_device_uuid (self .device_id )
1101- handles = update_weight_request .weight_ipc_handles [device_uuid ]
1102- weights = {}
1103-
1104- # Process each handle to get the tensor
1105- for name , handle in handles :
1106- func , args = handle
1107- list_args = list (args )
1108- # Update device ID to match the current device
1109- list_args [6 ] = self .device_id
1110- tensor = func (* list_args )
1111- weights [name ] = tensor
1112-
1113- # Load weights into the model
1114- self .model_engine .model .load_weights (weights )
1115-
1116- torch .cuda .synchronize ()
1137+ self .update_weight_from_ipc_handles (update_weight_request .weight_ipc_handles )
11171138 update_weight_response = LlmResponse (request_id = update_weight_request .id , result = LlmResult (result = None , py_result = PyResult (0 , 0 , success = True ), is_final = True ), client_id = update_weight_request .id )
11181139 self ._enqueue_responses ({update_weight_request .id : update_weight_response })
11191140 except Exception as e :
0 commit comments