
Commit 3785dfe

committed: fixes

1 parent 46ba9f6 commit 3785dfe

File tree: 2 files changed, +17 -4 lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 16 additions & 3 deletions
@@ -2387,17 +2387,30 @@ def _maybe_expand_transformer_param_shape_or_error_(
     def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
         expanded_module_names = set()
         transformer_state_dict = transformer.state_dict()
-        lora_module_names = sorted({k.replace(".lora_A.weight", "") for k in lora_state_dict if "lora_A" in k})
-        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        prefix = f"{cls.transformer_name}."
+
+        lora_module_names = [
+            key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+        ]
+        lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+        lora_module_names = sorted(set(lora_module_names))
+        transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+        unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+        if unexpected_modules:
+            logger.info(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
 
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for k in lora_module_names:
+            if k in unexpected_modules:
+                continue
+
             base_param_name = (
                 f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight"
                 if is_peft_loaded
                 else f"{k.replace(f'{cls.transformer_name}.', '')}.weight"
             )
             base_weight_param = transformer_state_dict[base_param_name]
-            lora_A_param = lora_state_dict[f"{k}.lora_A.weight"]
+            lora_A_param = lora_state_dict[f"{cls.transformer_name}.{k}.lora_A.weight"]
 
             if base_weight_param.shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
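
For reference, the net effect of the new key filtering can be sketched as a standalone snippet. This is only an illustration: the helper name filter_lora_module_names and the example state-dict keys are hypothetical and not part of the diffusers API.

# Hypothetical standalone sketch of the key filtering introduced above; the helper
# name and example keys are illustrative only and not part of diffusers.
def filter_lora_module_names(lora_state_dict, transformer_module_names, prefix="transformer."):
    # Keep only LoRA down-projection keys and strip the ".lora_A.weight" suffix.
    names = [k[: -len(".lora_A.weight")] for k in lora_state_dict if k.endswith(".lora_A.weight")]
    # Keep only keys scoped under the transformer prefix, then drop that prefix.
    names = sorted({n[len(prefix):] for n in names if n.startswith(prefix)})
    # Module names that do not exist on the transformer are logged and skipped by the loader.
    unexpected = set(names) - set(transformer_module_names)
    return names, unexpected


state_dict = {
    "transformer.blocks.0.attn.to_q.lora_A.weight": None,
    "transformer.blocks.0.attn.to_q.lora_B.weight": None,   # ignored: not a lora_A key
    "text_encoder.q_proj.lora_A.weight": None,               # ignored: wrong prefix
    "transformer.does_not_exist.lora_A.weight": None,        # flagged as unexpected
}
print(filter_lora_module_names(state_dict, ["blocks.0.attn.to_q"]))
# (['blocks.0.attn.to_q', 'does_not_exist'], {'does_not_exist'})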

tests/lora/test_lora_layers_flux.py

Lines changed: 1 addition & 1 deletion
@@ -352,7 +352,7 @@ def test_lora_parameter_expanded_shapes(self):
         }
         # We should error out because lora input features is less than original. We only
         # support expanding the module, not shrinking it
-        with self.assertRaises(NotImplementedError):
+        with self.assertRaises(RuntimeError):
            pipe.load_lora_weights(lora_state_dict, "adapter-1")
 
     @require_peft_version_greater("0.13.2")
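
For context, the shape rule this test exercises can be sketched in isolation. This is a simplified illustration derived from the hunk above, not the actual diffusers implementation; the function name and the zero-padding strategy are assumptions.

import torch

# Simplified, hypothetical sketch of the shape rule: a LoRA A matrix narrower than the
# base layer's input features can be padded up, while the opposite case is rejected
# (the updated test expects that case to surface as a RuntimeError).
def maybe_expand_lora_A(base_weight: torch.Tensor, lora_A: torch.Tensor) -> torch.Tensor:
    if base_weight.shape[1] > lora_A.shape[1]:
        # Assumed expansion strategy: zero-pad the missing input features.
        expanded = torch.zeros(lora_A.shape[0], base_weight.shape[1], dtype=lora_A.dtype)
        expanded[:, : lora_A.shape[1]] = lora_A
        return expanded
    if base_weight.shape[1] < lora_A.shape[1]:
        raise RuntimeError("LoRA has more input features than the base layer; shrinking is unsupported.")
    return lora_A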
