@@ -724,6 +724,9 @@ def load_single_module(name, module):
724724 for new_name in params_map [names [- 1 ]]:
725725 fw = filter_weights ('.' .join (names [:- 1 ] + [new_name ]),
726726 weights )
727+ # tmp fixes to enable partial updates in old path
728+ if not fw :
729+ continue
727730 if new_name in ['k_proj' , 'v_proj' ]:
728731 num_kv_heads_list = [num_kv_heads
729732 ] * len (fw ) if isinstance (
@@ -740,15 +743,18 @@ def load_single_module(name, module):
740743 }
741744
742745 module_weights .append (fw )
743- module .load_weights (weights = module_weights )
746+ if module_weights :
747+ module .load_weights (weights = module_weights )
748+
744749 else :
745750 module_weights = filter_weights (name , weights )
746- if hasattr (module , 'load_weights' ):
747- module .load_weights (weights = [module_weights ])
748- else :
749- for n , p in module ._parameters .items ():
750- if p is not None :
751- p .data .copy_ (module_weights [n ][:])
751+ if module_weights :
752+ if hasattr (module , 'load_weights' ):
753+ module .load_weights (weights = [module_weights ])
754+ else :
755+ for n , p in module ._parameters .items ():
756+ if p is not None :
757+ p .data .copy_ (module_weights [n ][:])
752758
753759 if os .environ .get ("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL" ,
754760 False ) in ["True" , "true" , "1" , "yes" , "y" ]:
0 commit comments