@@ -2387,17 +2387,30 @@ def _maybe_expand_transformer_param_shape_or_error_(
23872387 def _maybe_expand_lora_state_dict (cls , transformer , lora_state_dict ):
23882388 expanded_module_names = set ()
23892389 transformer_state_dict = transformer .state_dict ()
2390- lora_module_names = sorted ({k .replace (".lora_A.weight" , "" ) for k in lora_state_dict if "lora_A" in k })
2391- is_peft_loaded = getattr (transformer , "peft_config" , None ) is not None
2390+ prefix = f"{ cls .transformer_name } ."
2391+
2392+ lora_module_names = [
2393+ key [: - len (".lora_A.weight" )] for key in lora_state_dict if key .endswith (".lora_A.weight" )
2394+ ]
2395+ lora_module_names = [name [len (prefix ) :] for name in lora_module_names if name .startswith (prefix )]
2396+ lora_module_names = sorted (set (lora_module_names ))
2397+ transformer_module_names = sorted ({name for name , _ in transformer .named_modules ()})
2398+ unexpected_modules = set (lora_module_names ) - set (transformer_module_names )
2399+ if unexpected_modules :
2400+ logger .info (f"Found unexpected modules: { unexpected_modules } . These will be ignored." )
23922401
2402+ is_peft_loaded = getattr (transformer , "peft_config" , None ) is not None
23932403 for k in lora_module_names :
2404+ if k in unexpected_modules :
2405+ continue
2406+
23942407 base_param_name = (
23952408 f"{ k .replace (f'{ cls .transformer_name } .' , '' )} .base_layer.weight"
23962409 if is_peft_loaded
23972410 else f"{ k .replace (f'{ cls .transformer_name } .' , '' )} .weight"
23982411 )
23992412 base_weight_param = transformer_state_dict [base_param_name ]
2400- lora_A_param = lora_state_dict [f"{ k } .lora_A.weight" ]
2413+ lora_A_param = lora_state_dict [f"{ cls . transformer_name } . { k } .lora_A.weight" ]
24012414
24022415 if base_weight_param .shape [1 ] > lora_A_param .shape [1 ]:
24032416 shape = (lora_A_param .shape [0 ], base_weight_param .shape [1 ])
0 commit comments