Skip to content

Commit ce5ae9e

Browse files
authored
Update lora_pipeline.py
1 parent 5466c53 commit ce5ae9e

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,7 +1550,7 @@ def load_lora_weights(
15501550
)
15511551

15521552
if has_param_with_expanded_shape:
1553-
if not hasattr(self,"_lora_unloading_reset_list"):
1553+
if not hasattr(self,"_lora_unloading_reset_list",None):
15541554
self._lora_unloading_reset_list = [adapter_name]
15551555
else:
15561556
self._lora_unloading_reset_list.append(adapter_name)
@@ -1842,7 +1842,7 @@ def fuse_lora(
18421842
logger.info(
18431843
"The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
18441844
"as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
1845-
"fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
1845+
"fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
18461846
)
18471847

18481848
super().fuse_lora(
@@ -1900,13 +1900,16 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
19001900

19011901
if reset_to_overwritten_params and transformer is not None:
19021902
self._maybe_reset_transformer(transformer)
1903-
if getattr(self,"_lora_unloading_reset_list",None):
1904-
self._lora_unloading_reset_list.clear()
1903+
if not getattr(self, "_lora_unloading_reset_list", None):
1904+
return
1905+
self._lora_unloading_reset_list.clear()
19051906

19061907
def delete_adapters(self, adapter_names: Union[List[str], str]):
19071908
super().delete_adapters(adapter_names)
19081909

19091910
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
1911+
if not getattr(self, "_lora_unloading_reset_list", None):
1912+
return
19101913
for adapter_name in adapter_names:
19111914
if adapter_name in self._lora_unloading_reset_list:
19121915
self._lora_unloading_reset_list.remove(adapter_name)

0 commit comments

Comments
 (0)