Skip to content

Commit 9835993

Browse files
committed
align interfaces with ray branch
1 parent 1a4baaf commit 9835993

File tree

1 file changed

+39
-18
lines changed

1 file changed

+39
-18
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,44 @@ def _prepare_draft_requests(self):
10841084
logger.error(f"Encountered an error in decode: {error_msg}")
10851085
self._handle_errors(error_msg)
10861086

1087+
def update_weights(self, weights):
    """Load ``weights`` into the engine's model and wait for the GPU to finish.

    Args:
        weights: Object handed unchanged to ``model.load_weights``; the only
            caller visible in this change passes a dict of parameter name to
            tensor.
    """
    model = self.model_engine.model
    model.load_weights(weights)

    # Block until all queued CUDA work has completed before returning.
    torch.cuda.synchronize()

    # TODO: reset prefix cache
1093+
1094+
def update_weight_from_ipc_handles(self, handles):
    """
    Update model weights from IPC handles.

    Args:
        handles (dict): Dictionary mapping device UUIDs to parameter IPC handles.
            {device_uuid: all_handles}, where all_handles is an iterable of
            (param_name, (func, args)) pairs.

    Returns:
        bool: True if every tensor was rebuilt and the weights were loaded;
            False if rebuilding or loading raised (the error is logged).
            Callers must check this value, since such failures are not
            re-raised.

    Raises:
        ValueError: If the UUID of this executor's device has no entry in
            ``handles``.
    """
    from tensorrt_llm._torch.utils import get_device_uuid
    device_uuid = get_device_uuid(self.device_id)

    if device_uuid not in handles:
        raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")

    try:
        weights = {}

        for param_name, (func, args) in handles[device_uuid]:
            list_args = list(args)
            # Set target device.
            # NOTE(review): presumably ``func`` is torch's
            # ``rebuild_cuda_tensor``, whose positional argument 6 is the
            # storage device index -- confirm against the handle producer.
            list_args[6] = self.device_id
            weights[param_name] = func(*list_args)

        self.update_weights(weights)

    except Exception as e:
        logger.error(f"failed to update weights from ipc handles: {e}")
        return False

    # Explicit success value: previously the success path fell through and
    # returned None, which is falsy just like the failure value.
    return True
1124+
10871125
def _sleep(self, sleep_request):
10881126
self.is_sleep_request = False
10891127
self._enqueue_responses({sleep_request.id: LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=sleep_request.id)})
@@ -1096,24 +1134,7 @@ def _update_weight(self, update_weight_request):
10961134
self.is_update_weight_request = False
10971135

10981136
try:
1099-
# Get handles for this device
1100-
device_uuid = get_device_uuid(self.device_id)
1101-
handles = update_weight_request.weight_ipc_handles[device_uuid]
1102-
weights = {}
1103-
1104-
# Process each handle to get the tensor
1105-
for name, handle in handles:
1106-
func, args = handle
1107-
list_args = list(args)
1108-
# Update device ID to match the current device
1109-
list_args[6] = self.device_id
1110-
tensor = func(*list_args)
1111-
weights[name] = tensor
1112-
1113-
# Load weights into the model
1114-
self.model_engine.model.load_weights(weights)
1115-
1116-
torch.cuda.synchronize()
1137+
self.update_weight_from_ipc_handles(update_weight_request.weight_ipc_handles)
11171138
update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=update_weight_request.id)
11181139
self._enqueue_responses({update_weight_request.id: update_weight_response})
11191140
except Exception as e:

0 commit comments

Comments
 (0)