Commit 970abcb

Update lora_pipeline.py
1 parent 2c1ed50 commit 970abcb

File tree

1 file changed (+24, -3 lines)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 24 additions & 3 deletions
@@ -1550,6 +1550,11 @@ def load_lora_weights(
         )
 
         if has_param_with_expanded_shape:
+            if not hasattr(self, "_lora_unloading_reset_list"):
+                self._lora_unloading_reset_list = [adapter_name]
+            else:
+                self._lora_unloading_reset_list.append(adapter_name)
+
             logger.info(
                 "The LoRA weights contain parameters that have different shapes than expected by the transformer. "
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
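The hunk above adds the bookkeeping half of the change: every adapter whose LoRA weights force the transformer's parameter shapes to be expanded is recorded on the pipeline in _lora_unloading_reset_list. The sketch below is a minimal, self-contained illustration of that create-or-append pattern; the ExpansionTracker class and the adapter names are invented for the example and are not part of diffusers.

class ExpansionTracker:
    """Illustrative stand-in for the bookkeeping added to load_lora_weights."""

    def record_expansion(self, adapter_name: str) -> None:
        # Create the list on the first expanding adapter, append on later ones.
        if not hasattr(self, "_lora_unloading_reset_list"):
            self._lora_unloading_reset_list = [adapter_name]
        else:
            self._lora_unloading_reset_list.append(adapter_name)


tracker = ExpansionTracker()
tracker.record_expansion("canny")
tracker.record_expansion("depth")
print(tracker._lora_unloading_reset_list)  # ['canny', 'depth']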
@@ -1893,7 +1898,25 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None
 
-        if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
+        if reset_to_overwritten_params and transformer is not None:
+            self._maybe_reset_transformer(transformer)
+            self._lora_unloading_reset_list.clear()
+
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        super().delete_adapters(adapter_names)
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        for adapter_name in adapter_names:
+            if adapter_name in self._lora_unloading_reset_list:
+                self._lora_unloading_reset_list.remove(adapter_name)
+        # If other LoRA adapters that expanded the transformer are still loaded, we don't reset it yet.
+        if len(self._lora_unloading_reset_list) == 0:
+            transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+            self._maybe_reset_transformer(transformer)
+
+    @classmethod
+    def _maybe_reset_transformer(cls, transformer: torch.nn.Module):
+        if getattr(transformer, "_overwritten_params", None) is not None:
             overwritten_params = transformer._overwritten_params
             module_names = set()
 
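The hunk above consumes that list: unload_lora_weights resets the transformer and clears the list, while the new delete_adapters override removes only the deleted adapters' entries and calls _maybe_reset_transformer once no expanding adapter remains. A minimal, self-contained model of that control flow follows; ResetBookkeeping and its transformer_reset flag are invented for illustration (the real helper swaps the expanded torch.nn.Linear modules back to their original shapes).

from typing import List, Union


class ResetBookkeeping:
    """Illustrative model of the reset-on-last-delete logic."""

    def __init__(self, expanding_adapters: List[str]):
        self._lora_unloading_reset_list = list(expanding_adapters)
        self.transformer_reset = False  # stands in for _maybe_reset_transformer's effect

    def delete_adapters(self, adapter_names: Union[List[str], str]) -> None:
        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
        for adapter_name in adapter_names:
            if adapter_name in self._lora_unloading_reset_list:
                self._lora_unloading_reset_list.remove(adapter_name)
        # Only reset once no adapter that expanded the transformer remains loaded.
        if len(self._lora_unloading_reset_list) == 0:
            self.transformer_reset = True


book = ResetBookkeeping(["canny", "depth"])
book.delete_adapters("canny")
print(book.transformer_reset)  # False: "depth" still needs the expanded shapes
book.delete_adapters("depth")
print(book.transformer_reset)  # True: last expanding adapter is gone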
@@ -1903,7 +1926,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
 
             for name, module in transformer.named_modules():
                 if isinstance(module, torch.nn.Linear) and name in module_names:
-                    module_weight = module.weight.data
                     module_bias = module.bias.data if module.bias is not None else None
                     bias = module_bias is not None
 
@@ -1917,7 +1939,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
                         in_features,
                         out_features,
                         bias=bias,
-                        dtype=module_weight.dtype,
                     )
 
                     tmp_state_dict = {"weight": current_param_weight}
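End to end, the commit means a Control-style LoRA that widens the transformer's input projection is tracked at load time, and the expanded modules are only restored when the last such adapter is unloaded or deleted. A hedged usage sketch follows; the checkpoint names, the "canny" adapter name, the x_embedder attribute, and the 64/128 widths are assumptions for illustration rather than details taken from this diff.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Loading a Control LoRA expands the transformer's input projection to accept extra
# channels; with this commit the adapter name is also recorded in
# pipe._lora_unloading_reset_list.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora", adapter_name="canny")
print(pipe.transformer.x_embedder.in_features)  # e.g. 128 after expansion

# Deleting the only expanding adapter now triggers _maybe_reset_transformer, restoring
# the original parameter shapes (provided the pre-expansion weights were cached in
# transformer._overwritten_params).
pipe.delete_adapters("canny")
print(pipe.transformer.x_embedder.in_features)  # e.g. back to 64

# unload_lora_weights(reset_to_overwritten_params=True) goes through the same helper.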

0 commit comments
