
Commit 143df0c

fix stuff
1 parent eb2ad02 commit 143df0c

2 files changed, +12 −15 lines changed


src/diffusers/loaders/lora_pipeline.py
Lines changed: 5 additions & 3 deletions

@@ -2312,15 +2312,17 @@ def _maybe_expand_transformer_param_shape_or_error_(
 
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
                 module_bias = module.bias.data if module.bias is not None else None
                 bias = module_bias is not None
 
-                lora_A_weight_name = f"{name}.lora_A.weight"
-                lora_B_weight_name = f"{name}.lora_B.weight"
-                if lora_A_weight_name not in state_dict.keys():
+                lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+                lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+                lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+                if lora_A_weight_name not in state_dict:
                     continue
 
                 in_features = state_dict[lora_A_weight_name].shape[1]
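
The `.base_layer` stripping matters because once PEFT injects an adapter, it wraps each targeted `torch.nn.Linear` and moves the original weights into a child module named `base_layer`, so `named_modules()` no longer yields the plain layer name that LoRA state-dict keys use. A minimal sketch of that renaming, using a hypothetical `Tiny` module rather than the actual Flux transformer:

import torch
from peft import LoraConfig, inject_adapter_in_model

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x_embedder = torch.nn.Linear(4, 8)

# Inject a rank-2 LoRA adapter into the (hypothetical) module.
model = inject_adapter_in_model(LoraConfig(r=2, target_modules=["x_embedder"]), Tiny())

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(name)
# x_embedder.base_layer       <- the original Linear, renamed by PEFT
# x_embedder.lora_A.default   <- adapter layers keyed by adapter name
# x_embedder.lora_B.default

# A LoRA state dict addresses this layer as "x_embedder.lora_A.weight",
# so the lookup only matches after name.replace(".base_layer", "").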

tests/lora/test_lora_layers_flux.py
Lines changed: 7 additions & 12 deletions

@@ -467,18 +467,14 @@ def test_lora_expanding_shape_with_normal_lora(self):
         pipe.load_lora_weights(lora_state_dict, "adapter-2")
 
         self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+        self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
         self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
 
         lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
         self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
 
         # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
-        # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
-        # original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
-        # weight is compatible with the current model inadequate. This should be addressed when attempting support for
-        # https://github.com/huggingface/diffusers/issues/10180 (TODO)
+        # This should raise a runtime error on input shapes being incompatible.
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         # Change the transformer config to mimic a real use case.
         num_channels_without_control = 4
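
The relocated assertion checks the "zero padded" log line before the second inference runs. The padding itself is simple: when the base layer has more input features than the incoming `lora_A`, the extra columns can be filled with zeros so the added channels contribute nothing. A small numeric sketch of that idea (hypothetical shapes, not the diffusers implementation):

import torch

lora_A_small = torch.randn(2, 4)   # rank 2, trained for in_features=4
padded = torch.zeros(2, 8)         # base layer now expects in_features=8
padded[:, :4] = lora_A_small

x = torch.randn(1, 8)
# The zero columns ignore the extra input channels entirely:
assert torch.allclose(x @ padded.T, x[:, :4] @ lora_A_small.T)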
@@ -501,24 +497,23 @@
             "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
             "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
         }
-
-        with CaptureLogger(logger) as cap_logger:
-            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+        pipe.load_lora_weights(lora_state_dict, "adapter-1")
 
         self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
         self.assertTrue(pipe.transformer.config.in_channels == in_features)
-        self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
         lora_state_dict = {
             "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
             "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
         }
 
-        # We should check for input shapes being incompatible here.
+        # We should check for input shapes being incompatible here. But because above mentioned issue is
+        # not a supported use case, and because of the PEFT renaming, we will currently have a shape
+        # mismatch error.
         self.assertRaisesRegex(
             RuntimeError,
-            "x_embedder.lora_A.weight",
+            "size mismatch for x_embedder.lora_A.adapter-2.weight",
             pipe.load_lora_weights,
             lora_state_dict,
             "adapter-2",
