Commit a2cdcda

check

1 parent da00c8d


src/diffusers/loaders/lora_pipeline.py

Lines changed: 5 additions & 8 deletions
@@ -2460,18 +2460,15 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
         if unexpected_modules:
             logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
 
-        transformer_base_layer_keys = {
-            k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k
-        }
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for k in lora_module_names:
             if k in unexpected_modules:
                 continue
 
-            base_param_name = (
-                f"{k.replace(prefix, '')}.base_layer.weight"
-                if k in transformer_base_layer_keys
-                else f"{k.replace(prefix, '')}.weight"
-            )
+            base_param_name = f"{k.replace(prefix, '')}.weight"
+            base_layer_name = f"{k.replace(prefix, '')}.base_layer.weight"
+            if is_peft_loaded and base_layer_name in transformer_state_dict:
+                base_param_name = base_layer_name
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
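The patch replaces the precomputed set of `.base_layer.weight` key stems with a direct lookup: it first checks whether PEFT adapters are loaded at all (`transformer.peft_config` is set), and only then probes the transformer state dict for the `.base_layer.weight` variant of each module's weight. Below is a minimal sketch of the new key-resolution logic, assuming prefixed LoRA module names and an unprefixed transformer state dict; `resolve_base_param_name` and the sample dicts are hypothetical names for illustration, only the selection logic mirrors the patched code.

    def resolve_base_param_name(k, prefix, transformer_state_dict, is_peft_loaded):
        # Default: the module's plain weight key.
        base_param_name = f"{k.replace(prefix, '')}.weight"
        # When PEFT wraps a module, the original weight lives under `.base_layer.weight`.
        base_layer_name = f"{k.replace(prefix, '')}.base_layer.weight"
        if is_peft_loaded and base_layer_name in transformer_state_dict:
            return base_layer_name
        return base_param_name

    # Plain model: `<module>.weight` is present in the state dict.
    plain = {"attn.to_q.weight": None}
    # PEFT-wrapped model: the same weight moves to `<module>.base_layer.weight`.
    wrapped = {"attn.to_q.base_layer.weight": None}

    assert resolve_base_param_name("transformer.attn.to_q", "transformer.", plain, False) == "attn.to_q.weight"
    assert resolve_base_param_name("transformer.attn.to_q", "transformer.", wrapped, True) == "attn.to_q.base_layer.weight"

Compared with the removed set comprehension, this avoids building `transformer_base_layer_keys` up front, and the `is_peft_loaded` guard ensures a model without adapters never resolves to a `.base_layer.weight` key.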
