Skip to content

Commit b7269f4

Browse files
committed
fixes
1 parent 3785dfe commit b7269f4

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2405,18 +2405,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
24052405
continue
24062406

24072407
base_param_name = (
2408-
f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight"
2409-
if is_peft_loaded
2410-
else f"{k.replace(f'{cls.transformer_name}.', '')}.weight"
2408+
f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
24112409
)
24122410
base_weight_param = transformer_state_dict[base_param_name]
2413-
lora_A_param = lora_state_dict[f"{cls.transformer_name}.{k}.lora_A.weight"]
2411+
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
24142412

24152413
if base_weight_param.shape[1] > lora_A_param.shape[1]:
24162414
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
24172415
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
24182416
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
2419-
lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight
2417+
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
24202418
expanded_module_names.add(k)
24212419
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
24222420
raise NotImplementedError(
@@ -2425,8 +2423,9 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
24252423

24262424
if expanded_module_names:
24272425
logger.info(
2428-
f"Found some LoRA modules for which the weights were zero-padded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
2426+
f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
24292427
)
2428+
24302429
return lora_state_dict
24312430

24322431

tests/lora/test_lora_layers_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ def test_load_regular_lora(self):
627627

628628
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
629629

630-
self.assertTrue("Found some LoRA modules for which the weights were zero-padded" in cap_logger.out)
630+
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
631631
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
632632
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
633633

0 commit comments

Comments (0)