
Commit f099b2f

check for base_layer key in transformer state dict
1 parent 1b202c5 commit f099b2f

File tree

1 file changed: +6 -2 lines


src/diffusers/loaders/lora_pipeline.py

Lines changed: 6 additions & 2 deletions
@@ -2460,13 +2460,17 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
         if unexpected_modules:
             logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
 
-        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        transformer_base_layer_keys = {
+            k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k
+        }
         for k in lora_module_names:
             if k in unexpected_modules:
                 continue
 
             base_param_name = (
-                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+                f"{k.replace(prefix, '')}.base_layer.weight"
+                if k in transformer_base_layer_keys
+                else f"{k.replace(prefix, '')}.weight"
             )
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
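
The change swaps the module-wide is_peft_loaded flag for a per-module check: a module's base weight is looked up under ".base_layer.weight" only if the transformer state dict actually contains that key for it, so state dicts where only some modules are PEFT-wrapped still resolve correctly. A minimal standalone sketch of that lookup (the state dict contents and module names below are made up for illustration, not taken from the commit):

# Sketch of the per-module base_layer check; hypothetical state dict for illustration.
# PEFT-wrapped modules expose their original weight under "<module>.base_layer.weight",
# while unwrapped modules keep the plain "<module>.weight" key.
transformer_state_dict = {
    "blocks.0.attn.to_q.base_layer.weight": "...",  # PEFT-wrapped module
    "blocks.0.proj_out.weight": "...",              # plain, unwrapped module
}

# Collect the prefixes of modules that are PEFT-wrapped, as the new code does.
transformer_base_layer_keys = {
    k[: -len(".base_layer.weight")] for k in transformer_state_dict if ".base_layer.weight" in k
}

for module_name in ["blocks.0.attn.to_q", "blocks.0.proj_out"]:
    # Pick whichever key actually exists for this module instead of relying on a global flag.
    base_param_name = (
        f"{module_name}.base_layer.weight"
        if module_name in transformer_base_layer_keys
        else f"{module_name}.weight"
    )
    print(module_name, "->", base_param_name)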
