Skip to content

Commit 1a4baaf

Browse files
committed
enable partial load
1 parent 3ad8422 commit 1a4baaf

File tree

1 file changed

+13 -7 lines changed

1 file changed

+13 -7 lines changed

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -724,6 +724,9 @@ def load_single_module(name, module):
724724
for new_name in params_map[names[-1]]:
725725
fw = filter_weights('.'.join(names[:-1] + [new_name]),
726726
weights)
727+
# Temporary fix: skip names with no matching weights so partial updates work in the old loading path
728+
if not fw:
729+
continue
727730
if new_name in ['k_proj', 'v_proj']:
728731
num_kv_heads_list = [num_kv_heads
729732
] * len(fw) if isinstance(
@@ -740,15 +743,18 @@ def load_single_module(name, module):
740743
}
741744

742745
module_weights.append(fw)
743-
module.load_weights(weights=module_weights)
746+
if module_weights:
747+
module.load_weights(weights=module_weights)
748+
744749
else:
745750
module_weights = filter_weights(name, weights)
746-
if hasattr(module, 'load_weights'):
747-
module.load_weights(weights=[module_weights])
748-
else:
749-
for n, p in module._parameters.items():
750-
if p is not None:
751-
p.data.copy_(module_weights[n][:])
751+
if module_weights:
752+
if hasattr(module, 'load_weights'):
753+
module.load_weights(weights=[module_weights])
754+
else:
755+
for n, p in module._parameters.items():
756+
if p is not None:
757+
p.data.copy_(module_weights[n][:])
752758

753759
if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
754760
False) in ["True", "true", "1", "yes", "y"]:

0 commit comments

Comments (0)