
Commit 1cc4c7e

fix
Co-authored-by: BenjaminBossan <[email protected]>
1 parent 94241d2


src/diffusers/loaders/lora_base.py

Lines changed: 8 additions & 8 deletions
@@ -465,6 +465,7 @@ class LoraBaseMixin:
     """Utility class for handling LoRAs."""
 
     _lora_loadable_modules = []
+    _merged_adapters = set()
 
     def load_lora_weights(self, **kwargs):
         raise NotImplementedError("`load_lora_weights()` is not implemented.")
@@ -504,12 +505,6 @@ def _best_guess_weight_name(cls, *args, **kwargs):
         deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
         return _best_guess_weight_name(*args, **kwargs)
 
-    @property
-    def _merged_adapters(self):
-        if "_merged_adapters" not in self.__dict__:
-            self.__dict__["_merged_adapters"] = set()
-        return self.__dict__["_merged_adapters"]
-
     def unload_lora_weights(self):
         """
         Unloads the LoRA parameters.
@@ -597,6 +592,9 @@ def fuse_lora(
         if len(components) == 0:
             raise ValueError("`components` cannot be an empty list.")
 
+        # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
+        # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
+        merged_adapter_names = set()
         for fuse_component in components:
             if fuse_component not in self._lora_loadable_modules:
                 raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -608,15 +606,17 @@ def fuse_lora(
                     model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
                     for module in model.modules():
                         if isinstance(module, BaseTunerLayer):
-                            self._merged_adapters.update(set(module.merged_adapters))
+                            merged_adapter_names.update(set(module.merged_adapters))
                 # handle transformers models.
                 if issubclass(model.__class__, PreTrainedModel):
                     fuse_text_encoder_lora(
                         model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
                     )
                     for module in model.modules():
                         if isinstance(module, BaseTunerLayer):
-                            self._merged_adapters.update(set(module.merged_adapters))
+                            merged_adapter_names.update(set(module.merged_adapters))
+
+        self._merged_adapters = self._merged_adapters | merged_adapter_names
 
     def unfuse_lora(self, components: List[str] = [], **kwargs):
         r"""
