Skip to content

Commit 7b5037f

Browse files
committed
fix things.
1 parent 11fd809 commit 7b5037f

File tree

3 files changed

+53
-6
lines changed

3 files changed

+53
-6
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,12 +2337,19 @@ def _maybe_expand_transformer_param_shape_or_error_(
23372337
f"this please open an issue at https://github.com/huggingface/diffusers/issues."
23382338
)
23392339

2340-
logger.debug(
2340+
debug_message = (
23412341
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
23422342
f"checkpoint contains higher number of features than expected. The number of input_features will be "
2343-
f"expanded from {module_in_features} to {in_features}, and the number of output features will be "
2344-
f"expanded from {module_out_features} to {out_features}."
2343+
f"expanded from {module_in_features} to {in_features}"
23452344
)
2345+
if module_out_features != out_features:
2346+
debug_message += (
2347+
", and the number of output features will be "
2348+
f"expanded from {module_out_features} to {out_features}."
2349+
)
2350+
else:
2351+
debug_message += "."
2352+
logger.debug(debug_message)
23462353

23472354
has_param_with_shape_update = True
23482355
parent_module_name, _, current_module_name = name.rpartition(".")

src/diffusers/loaders/peft.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
205205
weights.
206206
"""
207207
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
208+
from peft.tuners.tuners_utils import BaseTunerLayer
208209

209210
cache_dir = kwargs.pop("cache_dir", None)
210211
force_download = kwargs.pop("force_download", False)
@@ -316,8 +317,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
316317
if is_peft_version(">=", "0.13.1"):
317318
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
318319

319-
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
320-
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
320+
# To handle scnearios where we cannot successfully set state dict. If it's unsucessful,
321+
# we should also delete the `peft_config` associated to the `adapter_name`.
322+
try:
323+
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
324+
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
325+
except RuntimeError as e:
326+
for module in self.modules():
327+
if isinstance(module, BaseTunerLayer):
328+
active_adapters = module.active_adapters
329+
for active_adapter in active_adapters:
330+
if adapter_name in active_adapter:
331+
module.delete_adapter(adapter_name)
332+
333+
self.peft_config.pop(adapter_name)
334+
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
335+
raise
321336

322337
warn_msg = ""
323338
if incompatible_keys is not None:

tests/lora/test_lora_layers_flux.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,14 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
435435
# another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
436436
# When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
437437
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
438+
439+
# Change the transformer config to mimic a real use case.
440+
num_channels_without_control = 4
441+
transformer = FluxTransformer2DModel.from_config(
442+
components["transformer"].config, in_channels=num_channels_without_control
443+
).to(torch_device)
444+
components["transformer"] = transformer
445+
438446
pipe = self.pipeline_class(**components)
439447
pipe = pipe.to(torch_device)
440448
pipe.set_progress_bar_config(disable=None)
@@ -453,12 +461,16 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
453461
}
454462
with CaptureLogger(logger) as cap_logger:
455463
pipe.load_lora_weights(lora_state_dict, "adapter-1")
456-
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
457464

465+
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
466+
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
458467
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
459468
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
460469
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
461470

471+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
472+
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
473+
462474
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
463475
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
464476
lora_state_dict = {
@@ -475,13 +487,26 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
475487
lora_state_dict,
476488
"adapter-2",
477489
)
490+
# We should have `adapter-1` as the only adapter.
491+
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
492+
493+
# Check if the output is the same after lora loading error
494+
lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
495+
self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))
478496

479497
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
480498
# This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
481499
# original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
482500
# weight is compatible with the current model inadequate. This should be addressed when attempting support for
483501
# https://github.com/huggingface/diffusers/issues/10180 (TODO)
484502
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
503+
# Change the transformer config to mimic a real use case.
504+
num_channels_without_control = 4
505+
transformer = FluxTransformer2DModel.from_config(
506+
components["transformer"].config, in_channels=num_channels_without_control
507+
).to(torch_device)
508+
components["transformer"] = transformer
509+
485510
pipe = self.pipeline_class(**components)
486511
pipe = pipe.to(torch_device)
487512
pipe.set_progress_bar_config(disable=None)

0 commit comments

Comments
 (0)