
Commit 7bb6c9f

propagate changes.
1 parent aa5cb3c commit 7bb6c9f

File tree

6 files changed: +140 -84 lines changed


src/diffusers/loaders/lora_base.py

Lines changed: 4 additions & 1 deletion

@@ -353,8 +353,11 @@ def _load_lora_into_text_encoder(
         raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
 
     # Load the layers corresponding to text encoder and make necessary adjustments.
+    if LORA_ADAPTER_METADATA_KEY in state_dict:
+        metadata = state_dict[LORA_ADAPTER_METADATA_KEY]
     if prefix is not None:
         state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        state_dict[LORA_ADAPTER_METADATA_KEY] = metadata
 
     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")

@@ -382,7 +385,7 @@ def _load_lora_into_text_encoder(
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
-        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False, prefix=prefix)
 
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:
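
Why the stash-and-re-attach step above is needed: the prefix filter in `_load_lora_into_text_encoder` keeps only keys that start with the prefix, so a top-level metadata entry would be silently dropped. A minimal, self-contained sketch of that behavior (the constant's string value and the dummy keys are assumptions for illustration, not the library's actual contents):

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # assumed string value, for illustration only

state_dict = {
    "text_encoder.lora_A.weight": "tensor-A",  # hypothetical LoRA entries
    "unet.lora_B.weight": "tensor-B",
    LORA_ADAPTER_METADATA_KEY: {"r": 4, "lora_alpha": 4},
}
prefix = "text_encoder"

metadata = state_dict.get(LORA_ADAPTER_METADATA_KEY)  # stash before filtering
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if metadata is not None:
    state_dict[LORA_ADAPTER_METADATA_KEY] = metadata  # re-attach after filtering

print(sorted(state_dict))  # ['lora_A.weight', 'lora_adapter_metadata']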

src/diffusers/loaders/lora_pipeline.py

Lines changed: 23 additions & 0 deletions

@@ -644,6 +644,9 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
+        from .lora_base import LORA_ADAPTER_METADATA_KEY
+
+        print(f"{LORA_ADAPTER_METADATA_KEY in state_dict=} before UNet")
         self.load_lora_into_unet(
             state_dict,
             network_alphas=network_alphas,

@@ -653,6 +656,7 @@ def load_lora_weights(
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
         )
+        print(f"{LORA_ADAPTER_METADATA_KEY in state_dict=} before text encoder.")
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,

@@ -664,6 +668,7 @@ def load_lora_weights(
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
         )
+        print(f"{LORA_ADAPTER_METADATA_KEY in state_dict=} before text encoder 2.")
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,

@@ -732,6 +737,7 @@ def lora_state_dict(
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)

@@ -914,6 +920,9 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        unet_lora_adapter_metadata=None,
+        text_encoder_lora_adapter_metadata=None,
+        text_encoder_2_lora_adapter_metadata=None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.

@@ -939,8 +948,12 @@ def save_lora_weights(
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            unet_lora_adapter_metadata: TODO
+            text_encoder_lora_adapter_metadata: TODO
+            text_encoder_2_lora_adapter_metadata: TODO
         """
         state_dict = {}
+        lora_adapter_metadata = {}
 
         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(

@@ -956,13 +969,23 @@ def save_lora_weights(
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
 
+        if unet_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(cls.pack_weights(unet_lora_adapter_metadata, cls.unet_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
+
+        if text_encoder_2_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2"))
+
         cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )
 
     def fuse_lora(
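
For context, the three new `*_lora_adapter_metadata` arguments are collected into a single `lora_adapter_metadata` dict, namespaced per component, and handed to `write_lora_layers`. A rough, self-contained sketch of that namespacing step, with `pack_weights` modeled as plain key-prefixing (an assumption for illustration, not the library implementation):

def pack_weights(entries, prefix):
    # Modeled here as "prefix every key with the component name".
    return {f"{prefix}.{k}": v for k, v in entries.items()}

unet_lora_adapter_metadata = {"r": 4, "lora_alpha": 4}            # e.g. a subset of LoraConfig.to_dict()
text_encoder_lora_adapter_metadata = {"r": 8, "lora_alpha": 8}

lora_adapter_metadata = {}
if unet_lora_adapter_metadata is not None:
    lora_adapter_metadata.update(pack_weights(unet_lora_adapter_metadata, "unet"))
if text_encoder_lora_adapter_metadata:
    lora_adapter_metadata.update(pack_weights(text_encoder_lora_adapter_metadata, "text_encoder"))

print(lora_adapter_metadata)
# {'unet.r': 4, 'unet.lora_alpha': 4, 'text_encoder.r': 8, 'text_encoder.lora_alpha': 8}

In practice the metadata would typically come from an adapter config, as the removed Wan test did with `denoiser_lora_config.to_dict()`.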

src/diffusers/loaders/peft.py

Lines changed: 5 additions & 6 deletions

@@ -193,7 +193,7 @@ def load_lora_adapter(
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
 
-        from .lora_base import LORA_ADAPTER_METADATA_KEY
+        from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
 
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)

@@ -234,15 +234,14 @@ def load_lora_adapter(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        if LORA_ADAPTER_METADATA_KEY in state_dict:
+            metadata = state_dict[LORA_ADAPTER_METADATA_KEY]
         if network_alphas is not None and prefix is None:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
         if prefix is not None:
-            state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
-
-            metadata = state_dict.pop(LORA_ADAPTER_METADATA_KEY, None)
-            if metadata is not None:
-                state_dict[LORA_ADAPTER_METADATA_KEY] = metadata
+            state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+            state_dict[LORA_ADAPTER_METADATA_KEY] = metadata
 
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
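
The rewritten prefix handling swaps the slice-based strip for `str.removeprefix` (Python 3.9+); the two are interchangeable here because the comprehension already filters on `startswith`. A small sketch with made-up keys:

prefix = "transformer"
keys = ["transformer.blocks.0.attn.lora_A.weight", "text_encoder.lora_A.weight"]

sliced = [k[len(f"{prefix}.") :] for k in keys if k.startswith(f"{prefix}.")]
stripped = [k.removeprefix(f"{prefix}.") for k in keys if k.startswith(f"{prefix}.")]

assert sliced == stripped == ["blocks.0.attn.lora_A.weight"]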

src/diffusers/utils/peft_utils.py

Lines changed: 6 additions & 4 deletions

@@ -158,10 +158,12 @@ def get_peft_kwargs(
 
     if LORA_ADAPTER_METADATA_KEY in peft_state_dict:
         metadata = peft_state_dict[LORA_ADAPTER_METADATA_KEY]
-        if metadata:
-            if prefix is not None:
-                metadata = {k.removeprefix(prefix + "."): v for k, v in metadata.items()}
-            return metadata
+    else:
+        metadata = None
+    if metadata:
+        if prefix is not None:
+            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
+        return metadata
 
     rank_pattern = {}
     alpha_pattern = {}
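
The metadata branch in `get_peft_kwargs` now also filters by prefix instead of only stripping it, so metadata entries that belong to another component are dropped rather than passed through untouched. A small sketch of the difference (keys are illustrative):

prefix = "text_encoder"
metadata = {"text_encoder.r": 8, "text_encoder.lora_alpha": 8, "unet.r": 4}

old_style = {k.removeprefix(prefix + "."): v for k, v in metadata.items()}
new_style = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

print(old_style)  # {'r': 8, 'lora_alpha': 8, 'unet.r': 4}
print(new_style)  # {'r': 8, 'lora_alpha': 8}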

tests/lora/test_lora_layers_wan.py

Lines changed: 1 addition & 70 deletions

@@ -13,10 +13,8 @@
 # limitations under the License.
 
 import sys
-import tempfile
 import unittest
 
-import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
 

@@ -26,13 +24,7 @@
     WanPipeline,
     WanTransformer3DModel,
 )
-from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
-from diffusers.utils.testing_utils import (
-    check_if_dicts_are_equal,
-    floats_tensor,
-    require_peft_backend,
-    torch_device,
-)
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
 
 
 sys.path.append(".")

@@ -145,64 +137,3 @@ def test_simple_inference_with_text_lora_fused(self):
     @unittest.skip("Text encoder LoRA is not supported in Wan.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
-
-    def test_lora_adapter_metadata_is_loaded_correctly(self):
-        # TODO: Will write the test in utils.py eventually.
-        scheduler_cls = self.scheduler_classes[0]
-        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-        pipe = self.pipeline_class(**components)
-
-        pipe, _ = self.check_if_adapters_added_correctly(
-            pipe, text_lora_config=None, denoiser_lora_config=denoiser_lora_config
-        )
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            metadata = denoiser_lora_config.to_dict()
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdir,
-                transformer_lora_adapter_metadata=metadata,
-                **lora_state_dicts,
-            )
-            pipe.unload_lora_weights()
-            state_dict = pipe.lora_state_dict(tmpdir)
-
-            self.assertTrue(LORA_ADAPTER_METADATA_KEY in state_dict)
-
-            parsed_metadata = state_dict[LORA_ADAPTER_METADATA_KEY]
-            parsed_metadata = {k[len("transformer.") :]: v for k, v in parsed_metadata.items()}
-            check_if_dicts_are_equal(parsed_metadata, metadata)
-
-    def test_lora_adapter_metadata_save_load_inference(self):
-        # Will write the test in utils.py eventually.
-        scheduler_cls = self.scheduler_classes[0]
-        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-        pipe = self.pipeline_class(**components).to(torch_device)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
-        pipe, _ = self.check_if_adapters_added_correctly(
-            pipe, text_lora_config=None, denoiser_lora_config=denoiser_lora_config
-        )
-        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            metadata = denoiser_lora_config.to_dict()
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdir,
-                transformer_lora_adapter_metadata=metadata,
-                **lora_state_dicts,
-            )
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(tmpdir)
-
-            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            self.assertTrue(
-                np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
-            )
