
Commit af0a668

Merge pull request #2 from shuyixiong/updatew
Updatew: enable partial load, align interfaces with ray branch
2 parents: 3ad8422 + 9835993


2 files changed (+52, -25 lines)


tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 13 additions & 7 deletions
@@ -724,6 +724,9 @@ def load_single_module(name, module):
             for new_name in params_map[names[-1]]:
                 fw = filter_weights('.'.join(names[:-1] + [new_name]),
                                     weights)
+                # tmp fixes to enable partial updates in old path
+                if not fw:
+                    continue
                 if new_name in ['k_proj', 'v_proj']:
                     num_kv_heads_list = [num_kv_heads
                                          ] * len(fw) if isinstance(
@@ -740,15 +743,18 @@ def load_single_module(name, module):
                     }

                 module_weights.append(fw)
-            module.load_weights(weights=module_weights)
+            if module_weights:
+                module.load_weights(weights=module_weights)
+
         else:
             module_weights = filter_weights(name, weights)
-            if hasattr(module, 'load_weights'):
-                module.load_weights(weights=[module_weights])
-            else:
-                for n, p in module._parameters.items():
-                    if p is not None:
-                        p.data.copy_(module_weights[n][:])
+            if module_weights:
+                if hasattr(module, 'load_weights'):
+                    module.load_weights(weights=[module_weights])
+                else:
+                    for n, p in module._parameters.items():
+                        if p is not None:
+                            p.data.copy_(module_weights[n][:])

     if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
                       False) in ["True", "true", "1", "yes", "y"]:
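
The guards added above are what make partial weight updates work with the existing load path: filter_weights() returns an empty dict for any module whose tensors are absent from a partial checkpoint, and such modules are now skipped instead of being handed an empty weight list. A minimal sketch of that behavior, assuming a simplified stand-in for filter_weights (the names below are illustrative, not the TensorRT-LLM implementation):

# Sketch of the partial-update guard logic with a hypothetical filter_weights().
def filter_weights(prefix, weights):
    # Simplified stand-in: keep entries under `prefix` and strip the prefix.
    return {k[len(prefix) + 1:]: v for k, v in weights.items()
            if k.startswith(prefix + '.')}

def load_single_module_sketch(name, split_names, weights):
    module_weights = []
    for new_name in split_names:
        fw = filter_weights(f"{name}.{new_name}", weights)
        if not fw:            # partial update: this projection was not provided
            continue
        module_weights.append(fw)
    if module_weights:        # only load when the partial checkpoint touched this module
        print(f"{name}: loading {len(module_weights)} weight group(s)")
    else:
        print(f"{name}: skipped (not in this partial update)")

# Only layer 0's q_proj is present in this (hypothetical) partial update.
partial = {"model.layers.0.attn.q_proj.weight": "tensor"}
load_single_module_sketch("model.layers.0.attn", ["q_proj", "k_proj", "v_proj"], partial)
load_single_module_sketch("model.layers.1.attn", ["q_proj", "k_proj", "v_proj"], partial)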

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 39 additions & 18 deletions
@@ -1084,6 +1084,44 @@ def _prepare_draft_requests(self):
             logger.error(f"Encountered an error in decode: {error_msg}")
             self._handle_errors(error_msg)

+    def update_weights(self, weights):
+        # Load weights into the model
+        self.model_engine.model.load_weights(weights)
+        torch.cuda.synchronize()
+
+        # TODO: reset prefix cache
+
+    def update_weight_from_ipc_handles(self, handles):
+        """
+        Update model weights from IPC handles.
+
+        Args:
+            ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles.
+                {device_uuid: all_handles}
+        """
+        from tensorrt_llm._torch.utils import get_device_uuid
+        device_uuid = get_device_uuid(self.device_id)
+
+        if device_uuid not in handles:
+            raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")
+
+        try:
+            weights = {}
+            all_handles = handles[device_uuid]
+
+            for param_name, tensor_handle in all_handles:
+                func, args = tensor_handle
+                list_args = list(args)
+                list_args[6] = self.device_id  # Set target device
+                tensor = func(*list_args)
+                weights[param_name] = tensor
+
+            self.update_weights(weights)
+
+        except Exception as e:
+            logger.error(f"failed to update weights from ipc handles: {e}")
+            return False
+
     def _sleep(self, sleep_request):
         self.is_sleep_request = False
         self._enqueue_responses({sleep_request.id: LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=sleep_request.id)})
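
Each tensor_handle consumed by the new update_weight_from_ipc_handles() is a (rebuild_fn, args) pair. A reasonable assumption, consistent with the list_args[6] overwrite above, is that these handles come from PyTorch's CUDA IPC reduction (torch.multiprocessing.reductions.reduce_tensor), whose rebuild arguments carry the storage device index at position 6. A hedged cross-process sketch of that round trip (not part of this commit):

# Producer exports a CUDA tensor as an IPC handle; a consumer process rebuilds
# it, retargeting the device index the same way update_weight_from_ipc_handles does.
import torch
import torch.multiprocessing as mp
from torch.multiprocessing.reductions import reduce_tensor

def consumer(handle, device_id, out_q):
    func, args = handle
    list_args = list(args)
    list_args[6] = device_id          # set the target device, as in the executor code
    tensor = func(*list_args)         # maps the producer's CUDA memory via IPC
    out_q.put(tensor.sum().item())

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    src = torch.ones(4, 4, device="cuda:0")
    handle = reduce_tensor(src)       # (rebuild_fn, args) -- a picklable IPC handle
    out_q = mp.Queue()
    p = mp.Process(target=consumer, args=(handle, 0, out_q))
    p.start()
    print("sum seen by consumer:", out_q.get())   # expected: 16.0
    p.join()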
@@ -1096,24 +1134,7 @@ def _update_weight(self, update_weight_request):
         self.is_update_weight_request = False

         try:
-            # Get handles for this device
-            device_uuid = get_device_uuid(self.device_id)
-            handles = update_weight_request.weight_ipc_handles[device_uuid]
-            weights = {}
-
-            # Process each handle to get the tensor
-            for name, handle in handles:
-                func, args = handle
-                list_args = list(args)
-                # Update device ID to match the current device
-                list_args[6] = self.device_id
-                tensor = func(*list_args)
-                weights[name] = tensor
-
-            # Load weights into the model
-            self.model_engine.model.load_weights(weights)
-
-            torch.cuda.synchronize()
+            self.update_weight_from_ipc_handles(update_weight_request.weight_ipc_handles)
             update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=update_weight_request.id)
             self._enqueue_responses({update_weight_request.id: update_weight_response})
         except Exception as e:
0 commit comments
