@@ -465,6 +465,7 @@ class LoraBaseMixin:
465465 """Utility class for handling LoRAs."""
466466
467467 _lora_loadable_modules = []
468+ _merged_adapters = set ()
468469
469470 def load_lora_weights (self , ** kwargs ):
470471 raise NotImplementedError ("`load_lora_weights()` is not implemented." )
@@ -504,12 +505,6 @@ def _best_guess_weight_name(cls, *args, **kwargs):
504505 deprecate ("_best_guess_weight_name" , "0.35.0" , deprecation_message )
505506 return _best_guess_weight_name (* args , ** kwargs )
506507
507- @property
508- def _merged_adapters (self ):
509- if "_merged_adapters" not in self .__dict__ :
510- self .__dict__ ["_merged_adapters" ] = set ()
511- return self .__dict__ ["_merged_adapters" ]
512-
513508 def unload_lora_weights (self ):
514509 """
515510 Unloads the LoRA parameters.
@@ -597,6 +592,9 @@ def fuse_lora(
597592 if len (components ) == 0 :
598593 raise ValueError ("`components` cannot be an empty list." )
599594
595+ # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
596+ # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
597+ merged_adapter_names = set ()
600598 for fuse_component in components :
601599 if fuse_component not in self ._lora_loadable_modules :
602600 raise ValueError (f"{ fuse_component } is not found in { self ._lora_loadable_modules = } ." )
@@ -608,15 +606,17 @@ def fuse_lora(
608606 model .fuse_lora (lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names )
609607 for module in model .modules ():
610608 if isinstance (module , BaseTunerLayer ):
611- self . _merged_adapters .update (set (module .merged_adapters ))
609+ merged_adapter_names .update (set (module .merged_adapters ))
612610 # handle transformers models.
613611 if issubclass (model .__class__ , PreTrainedModel ):
614612 fuse_text_encoder_lora (
615613 model , lora_scale = lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names
616614 )
617615 for module in model .modules ():
618616 if isinstance (module , BaseTunerLayer ):
619- self ._merged_adapters .update (set (module .merged_adapters ))
617+ merged_adapter_names .update (set (module .merged_adapters ))
618+
619+ self ._merged_adapters = self ._merged_adapters | merged_adapter_names
620620
621621 def unfuse_lora (self , components : List [str ] = [], ** kwargs ):
622622 r"""
0 commit comments