@@ -1550,6 +1550,11 @@ def load_lora_weights(
15501550 )
15511551
15521552 if has_param_with_expanded_shape :
1553+ if not hasattr (self ,"_unloading_reset_list" ):
1554+ self ._lora_unloading_reset_list = [adapter_name ]
1555+ else :
1556+ self ._lora_unloading_reset_list .append (adapter_name )
1557+
15531558 logger .info (
15541559 "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
15551560 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
@@ -1893,7 +1898,25 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
18931898 transformer .load_state_dict (transformer ._transformer_norm_layers , strict = False )
18941899 transformer ._transformer_norm_layers = None
18951900
1896- if reset_to_overwritten_params and getattr (transformer , "_overwritten_params" , None ) is not None :
1901+ if reset_to_overwritten_params and transformer is not None :
1902+ self ._maybe_reset_transformer (transformer )
1903+ self ._lora_unloading_reset_list .clear ()
1904+
1905+ def delete_adapters (self , adapter_names : Union [List [str ], str ]):
1906+ super ().delete_adapters (adapter_names )
1907+
1908+ adapter_names = [adapter_names ] if isinstance (adapter_names , str ) else adapter_names
1909+ for adapter_name in adapter_names :
1910+ if adapter_name in self ._lora_unloading_reset_list :
1911+ self ._lora_unloading_reset_list .remove (adapter_name )
1912+ # If more than 1 LoRA adapters expanded the transformer, we don't need to resest the transformer.
1913+ if len (self ._lora_unloading_reset_list ) == 0 :
1914+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer
1915+ self ._maybe_reset_transformer (transformer )
1916+
1917+ @classmethod
1918+ def _maybe_reset_transformer (cls , transformer : torch .nn .Module ):
1919+ if getattr (transformer , "_overwritten_params" , None ) is not None :
18971920 overwritten_params = transformer ._overwritten_params
18981921 module_names = set ()
18991922
@@ -2131,6 +2154,7 @@ def _get_weight_shape(weight: torch.Tensor):
21312154 raise ValueError ("Either `base_module` or `base_weight_param_name` must be provided." )
21322155
21332156
2157+
21342158# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
21352159# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
21362160class AmusedLoraLoaderMixin (StableDiffusionLoraLoaderMixin ):
0 commit comments