Commit 0193f2c

Update lora_pipeline.py
1 parent 2c1ed50 commit 0193f2c

File tree

1 file changed (+25 -1)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 25 additions & 1 deletion
@@ -1550,6 +1550,11 @@ def load_lora_weights(
         )
 
         if has_param_with_expanded_shape:
+            if not hasattr(self, "_lora_unloading_reset_list"):
+                self._lora_unloading_reset_list = [adapter_name]
+            else:
+                self._lora_unloading_reset_list.append(adapter_name)
+
             logger.info(
                 "The LoRA weights contain parameters that have different shapes than expected by the transformer. "
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
@@ -1893,7 +1898,25 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None
 
-        if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
+        if reset_to_overwritten_params and transformer is not None:
+            self._maybe_reset_transformer(transformer)
+            self._lora_unloading_reset_list.clear()
+
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        super().delete_adapters(adapter_names)
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        for adapter_name in adapter_names:
+            if adapter_name in self._lora_unloading_reset_list:
+                self._lora_unloading_reset_list.remove(adapter_name)
+        # If other LoRA adapters that expanded the transformer are still loaded, we don't need to reset it.
+        if len(self._lora_unloading_reset_list) == 0:
+            transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+            self._maybe_reset_transformer(transformer)
+
+    @classmethod
+    def _maybe_reset_transformer(cls, transformer: torch.nn.Module):
+        if getattr(transformer, "_overwritten_params", None) is not None:
             overwritten_params = transformer._overwritten_params
             module_names = set()
 
@@ -2131,6 +2154,7 @@ def _get_weight_shape(weight: torch.Tensor):
         raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
 
 
+
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
 class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
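
A minimal usage sketch (not part of the commit) of the behavior this change targets, assuming a Flux pipeline; the model id, LoRA repo id, and adapter name below are illustrative only:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# A shape-expanding LoRA (e.g. a Control LoRA) triggers the expanded-shape branch in
# load_lora_weights(), which now records the adapter name in pipe._lora_unloading_reset_list.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora", adapter_name="canny")

# Deleting the last tracked adapter empties the list, so _maybe_reset_transformer() restores
# the transformer parameters that were overwritten by the shape expansion.
pipe.delete_adapters("canny")

# Alternatively, unloading with reset_to_overwritten_params=True performs the same reset
# and clears the tracking list:
# pipe.unload_lora_weights(reset_to_overwritten_params=True)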
