@@ -2387,17 +2387,30 @@ def _maybe_expand_transformer_param_shape_or_error_(
     def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
         expanded_module_names = set()
         transformer_state_dict = transformer.state_dict()
-        lora_module_names = sorted({k.replace(".lora_A.weight", "") for k in lora_state_dict if "lora_A" in k})
-        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        prefix = f"{cls.transformer_name}."
+
+        lora_module_names = [
+            key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+        ]
+        lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+        lora_module_names = sorted(set(lora_module_names))
+        transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+        unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+        if unexpected_modules:
+            logger.info(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
 
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for k in lora_module_names:
+            if k in unexpected_modules:
+                continue
+
             base_param_name = (
                 f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight"
                 if is_peft_loaded
                 else f"{k.replace(f'{cls.transformer_name}.', '')}.weight"
             )
             base_weight_param = transformer_state_dict[base_param_name]
-            lora_A_param = lora_state_dict[f"{k}.lora_A.weight"]
+            lora_A_param = lora_state_dict[f"{cls.transformer_name}.{k}.lora_A.weight"]
24012414
24022415            if  base_weight_param .shape [1 ] >  lora_A_param .shape [1 ]:
24032416                shape  =  (lora_A_param .shape [0 ], base_weight_param .shape [1 ])
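For reference, a minimal standalone sketch (not part of the PR) of what the added filtering does: keep only `.lora_A.weight` keys, strip the transformer prefix, and collect modules named in the LoRA but absent from the model so the loop can skip them instead of raising a `KeyError`. The prefix and module names below are made up for illustration.

```python
# Sketch of the new module-name filtering, with hypothetical names.
prefix = "transformer."
lora_state_dict = {
    "transformer.single_blocks.0.proj.lora_A.weight": None,  # placeholder values
    "transformer.single_blocks.0.proj.lora_B.weight": None,
    "transformer.nonexistent_block.lora_A.weight": None,     # not present in the model
}
transformer_module_names = {"single_blocks.0.proj", "single_blocks.0.norm"}

# Derive module names from lora_A keys, then drop entries outside the transformer prefix.
lora_module_names = [
    key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
]
lora_module_names = sorted({name[len(prefix):] for name in lora_module_names if name.startswith(prefix)})

# Modules the model doesn't have are reported and skipped rather than failing the load.
unexpected_modules = set(lora_module_names) - transformer_module_names
print(unexpected_modules)  # {'nonexistent_block'}
```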