@@ -871,6 +871,8 @@ def load_single_module(name, module):
                 for new_name in params_map[names[-1]]:
                     fw = filter_weights('.'.join(names[:-1] + [new_name]),
                                         weights)
+                    if not fw:
+                        continue
                     if new_name in ['k_proj', 'v_proj']:
                         num_kv_heads_list = [num_kv_heads
                                              ] * len(fw) if isinstance(
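The guard added above skips parameter names that match nothing in the checkpoint. A minimal sketch of the pattern, assuming filter_weights() returns a dict of checkpoint tensors whose fully qualified names start with the given prefix, re-keyed to local names (the real helper lives elsewhere in this file):

# Minimal sketch, assuming filter_weights() keeps only entries whose fully
# qualified name starts with the requested prefix, re-keyed to local names.
def filter_weights(prefix, weights):
    out = {}
    for name, tensor in weights.items():
        if name.startswith(prefix + '.'):
            out[name[len(prefix) + 1:]] = tensor  # e.g. "weight", "bias"
    return out

weights = {"layers.0.attn.q_proj.weight": ...}
# During a partial (streaming) update, k_proj may be absent entirely:
fw = filter_weights("layers.0.attn.k_proj", weights)
assert fw == {}  # without the new `if not fw: continue`, an empty dict
                 # would be appended and later break fused QKV loading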
@@ -887,23 +889,29 @@ def load_single_module(name, module):
                         }

                     module_weights.append(fw)
-                module.load_weights(weights=module_weights)
+                # Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
+                if module_weights:
+                    module.load_weights(weights=module_weights)
+
             else:
                 module_weights = filter_weights(name, weights)
-                if hasattr(module, 'load_weights'):
-                    module.load_weights(weights=[module_weights])
-                else:
-                    for n, p in module._parameters.items():
-                        if p is not None:
+                # Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
+                if module_weights:
+                    if hasattr(module, 'load_weights'):
+                        module.load_weights(weights=[module_weights])
+                    else:
+                        for n, p in module.named_parameters(recurse=False):
                             p.data.copy_(module_weights[n][:])

     if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
                       "True") in ["True", "true", "1", "yes", "y"]:
-        for name, module in tqdm(list(model.named_modules()),
+        for name, module in tqdm(list(
+                model.named_modules(remove_duplicate=False)),
                                  desc="Loading weights"):
             load_single_module(name, module)
     else:
-        all_modules = dict(model.named_modules())
+        # remove_duplicate=False ensures original modules sharing weights with next_layer_layernorm are not skipped
+        all_modules = dict(model.named_modules(remove_duplicate=False))
         serial_load_modules = []
         if preload_weight_modules is not None:
             for module in preload_weight_modules:
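The switch to named_modules(remove_duplicate=False) matters because PyTorch deduplicates by module object: each shared module is yielded once, under the first name reached during traversal. When a layer registers an alias such as next_layer_layernorm to the next layer's norm, the original name can be the one dropped, and its weights would then never be visited by the loading loop. A runnable toy with hypothetical module names:

import torch.nn as nn

class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(8)

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = Layer()
        self.layer1 = Layer()
        # Alias: layer0 holds a reference to layer1's norm (shared weights).
        self.layer0.next_layer_layernorm = self.layer1.input_layernorm

m = Toy()
print([n for n, _ in m.named_modules()])
# ['', 'layer0', 'layer0.input_layernorm', 'layer0.next_layer_layernorm', 'layer1']
# -> 'layer1.input_layernorm' is deduplicated away and would be skipped.
print([n for n, _ in m.named_modules(remove_duplicate=False)])
# ...additionally yields 'layer1.input_layernorm', so both names are visited.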
@@ -919,10 +927,13 @@ def load_single_module(name, module):
                 del all_modules[module]
             pbar.close()

-        pbar = tqdm(list(model.named_modules()),
+        pbar = tqdm(list(model.named_modules(remove_duplicate=False)),
                     desc="Loading weights concurrently")
-        args_list = [(name, module) for name, module in model.named_modules()
-                     if name not in serial_load_modules]
+        args_list = [
+            (name, module)
+            for name, module in model.named_modules(remove_duplicate=False)
+            if name not in serial_load_modules
+        ]
         run_concurrently(load_single_module, args_list, pbar=pbar)

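run_concurrently is a helper defined elsewhere in the tree; its exact implementation is not shown here. As a rough mental model only, a hypothetical stand-in that fans the (name, module) pairs out to a thread pool and ticks the progress bar as futures complete:

from concurrent.futures import ThreadPoolExecutor, as_completed

# Hypothetical stand-in for the real helper: apply `func` to each args tuple
# on a thread pool, updating `pbar` as work finishes. Threads are a plausible
# fit here since large tensor copies release the GIL.
def run_concurrently(func, args_list, pbar=None, max_workers=8):
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(func, *args) for args in args_list]
        for future in as_completed(futures):
            future.result()  # re-raise any loading error
            if pbar is not None:
                pbar.update(1)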
@@ -950,31 +961,36 @@ def load_single_module(name, module):
             if weight_mapper.does_require_special_handling(module_name):
                 module_weights = weight_mapper.apply_callbacks(
                     module, module_name, module_names_breakdown, weights)
-                module.load_weights(weights=module_weights)
+                # Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
+                if module_weights:
+                    module.load_weights(weights=module_weights)
             else:
                 module_weights = weight_mapper.filter_weights(name, weights)
-                if weight_mapper.is_special_instance_module(module):
-                    weight_mapper.handle_special_instance_module(
-                        module, module_name, module_weights)
-
-                elif hasattr(module, 'load_weights'):
-                    if "linear_attn.conv1d" in name:
-                        module_weights['weight'] = module_weights[
-                            'weight'].squeeze(dim=1)
-                    module.load_weights(weights=[module_weights])
-                else:
-                    for n, p in module._parameters.items():
-                        if p is not None:
+                # Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
+                if module_weights:
+                    if weight_mapper.is_special_instance_module(module):
+                        weight_mapper.handle_special_instance_module(
+                            module, module_name, module_weights)
+                    elif hasattr(module, 'load_weights'):
+                        if module_weights:
+                            if "linear_attn.conv1d" in name:
+                                module_weights['weight'] = module_weights[
+                                    'weight'].squeeze(dim=1)
+                            module.load_weights(weights=[module_weights])
+                    else:
+                        for n, p in module.named_parameters(recurse=False):
                             weight_mapper.handle_manual_copy(
                                 module_name, module_weights, n, p)

     if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
                       "True") in ["True", "true", "1", "yes", "y"]:
-        for name, module in tqdm(list(model.named_modules()),
+        for name, module in tqdm(list(
+                model.named_modules(remove_duplicate=False)),
                                  desc="Loading weights"):
             load_single_module(name, module)
     else:
-        all_modules = dict(model.named_modules())
+        # remove_duplicate=False ensures original modules sharing weights with next_layer_layernorm are not skipped
+        all_modules = dict(model.named_modules(remove_duplicate=False))
         serial_load_modules = []
         if preload_weight_modules is not None:
             for module in preload_weight_modules:
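Dropping the old `if p is not None` check is safe here because named_parameters(recurse=False) already filters out None entries, whereas module._parameters can hold None placeholders (e.g., a Linear registered with bias=False). A small demonstration:

import torch.nn as nn

lin = nn.Linear(4, 4, bias=False)  # bias is registered as None
print(dict(lin._parameters).keys())            # dict_keys(['weight', 'bias'])
print(lin._parameters['bias'])                 # None -> needed `if p is not None`
print([n for n, _ in lin.named_parameters(recurse=False)])  # ['weight']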
@@ -990,8 +1006,11 @@ def load_single_module(name, module):
                 del all_modules[module]
             pbar.close()

-        pbar = tqdm(list(model.named_modules()),
+        pbar = tqdm(list(model.named_modules(remove_duplicate=False)),
                     desc="Loading weights concurrently")
-        args_list = [(name, module) for name, module in model.named_modules()
-                     if name not in serial_load_modules]
+        args_list = [
+            (name, module)
+            for name, module in model.named_modules(remove_duplicate=False)
+            if name not in serial_load_modules
+        ]
         run_concurrently(load_single_module, args_list, pbar=pbar)
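For reference, the environment toggle that selects between the serial and concurrent paths, isolated into a helper (same semantics as the inline check above; the function name is ours, not from the source):

import os

# Serial loading is the default: the variable defaults to "True", and any of
# the listed truthy spellings keep the parallel path disabled.
def load_weights_in_parallel_disabled() -> bool:
    return os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
                          "True") in ["True", "true", "1", "yes", "y"]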