
Commit 7f59ca0

load metadata.

1 parent 7ec4ef4

6 files changed (+38, -73 lines)

src/diffusers/loaders/lora_base.py

Lines changed: 2 additions & 7 deletions
@@ -64,6 +64,7 @@

 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"


 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):

@@ -208,7 +209,6 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
-    load_with_metadata=False,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):

@@ -226,8 +226,6 @@ def _fetch_state_dict(
                     file_extension=".safetensors",
                     local_files_only=local_files_only,
                 )
-                if load_with_metadata and not weight_name.endswith(".safetensors"):
-                    raise ValueError("`load_with_metadata` cannot be set to True when not using safetensors.")

                 model_file = _get_model_file(
                     pretrained_model_name_or_path_or_dict,

@@ -242,10 +240,7 @@ def _fetch_state_dict(
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                if load_with_metadata:
-                    state_dict = _maybe_populate_state_dict_with_metadata(
-                        state_dict, model_file, metadata_key="lora_adapter_metadata"
-                    )
+                state_dict = _maybe_populate_state_dict_with_metadata(state_dict, model_file)

             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
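Note: the net effect of this change is that metadata population is no longer opt-in. Every safetensors load in `_fetch_state_dict` now attempts it, and non-safetensors files pass through untouched. A minimal sketch of the resulting flow, assuming a local `pytorch_lora_weights.safetensors` file (the path is illustrative):

```python
# Sketch only: mirrors the new unconditional call in `_fetch_state_dict`.
import safetensors.torch

from diffusers.utils.state_dict_utils import _maybe_populate_state_dict_with_metadata

model_file = "pytorch_lora_weights.safetensors"  # illustrative local path
state_dict = safetensors.torch.load_file(model_file, device="cpu")
state_dict = _maybe_populate_state_dict_with_metadata(state_dict, model_file)

# If the file header carried adapter config metadata, it is now available under
# state_dict["lora_adapter_metadata"]; otherwise the state dict is returned unchanged.
```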

src/diffusers/loaders/lora_pipeline.py

Lines changed: 17 additions & 52 deletions
@@ -4727,7 +4727,6 @@ def lora_state_dict(
                     - A [torch state
                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

-            load_with_metadata: TODO
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.

@@ -4762,7 +4761,6 @@ def lora_state_dict(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
-        load_with_metadata = kwargs.pop("load_with_metadata", False)

         allow_pickle = False
         if use_safetensors is None:

@@ -4787,7 +4785,6 @@ def lora_state_dict(
             subfolder=subfolder,
             user_agent=user_agent,
             allow_pickle=allow_pickle,
-            load_with_metadata=load_with_metadata,
         )
         if any(k.startswith("diffusion_model.") for k in state_dict):
             state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)

@@ -4861,7 +4858,6 @@ def load_lora_weights(
             raise ValueError("PEFT backend is required for this method.")

         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
-        load_with_metadata = kwargs.get("load_with_metadata", False)
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."

@@ -4888,7 +4884,6 @@ def load_lora_weights(
             adapter_name=adapter_name,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            load_with_metadata=load_with_metadata,
             hotswap=hotswap,
         )

@@ -4902,54 +4897,25 @@ def load_lora_into_transformer(
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
-        load_with_metadata: bool = False,
     ):
         """
-        This will load the LoRA layers specified in `state_dict` into `transformer`.
-
-        Parameters:
-            state_dict (`dict`):
-                A standard state dict containing the lora layer parameters. The keys can either be indexed
-                directly into the unet or prefixed with an additional `unet` which can be used to distinguish
-                between text encoder lora layers.
-            transformer (`WanTransformer3DModel`):
-                The Transformer model to load the LoRA layers into.
-            adapter_name (`str`, *optional*):
-                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
-                `default_{i}` where i is the total number of adapters being loaded.
-            low_cpu_mem_usage (`bool`, *optional*):
-                Speed up model loading by only loading the pretrained LoRA weights and not initializing the
-                random weights.
-<<<<<<< HEAD
-            hotswap : (`bool`, *optional*)
-                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded
-                adapter in-place. This means that, instead of loading an additional adapter, this will take the
-                existing adapter weights and replace them with the weights of the new adapter. This can be
-                faster and more memory efficient. However, the main advantage of hotswapping is that when the
-                model is compiled with torch.compile, loading the new adapter does not require recompilation of
-                the model. When using hotswapping, the passed `adapter_name` should be the name of an already
-                loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling),
-                you need to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ... # load diffusers pipeline
-                max_rank = ... # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
-            load_with_metadata: TODO
-=======
-            hotswap (`bool`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
->>>>>>> main
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`WanTransformer3DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(

@@ -4965,7 +4931,6 @@ def load_lora_into_transformer(
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
-            load_with_metadata=load_with_metadata,
         )

     @classmethod
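Note: with `load_with_metadata` gone from the public entry points, callers pass nothing extra; the metadata simply appears in the state dict when the underlying checkpoint is a safetensors file. A hedged usage sketch (the repo id is hypothetical):

```python
from diffusers import WanPipeline

# `lora_state_dict` is a classmethod, so the full pipeline does not have to be
# instantiated just to inspect a checkpoint.
state_dict = WanPipeline.lora_state_dict("some-user/some-wan-lora")  # hypothetical repo id

# Present only when the checkpoint embeds an adapter config in its safetensors header.
metadata = state_dict.get("lora_adapter_metadata")
```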

src/diffusers/loaders/peft.py

Lines changed: 0 additions & 4 deletions
@@ -120,7 +120,6 @@ def load_lora_adapter(
         pretrained_model_name_or_path_or_dict,
         prefix="transformer",
         hotswap: bool = False,
-        load_with_metadata: bool = False,
         **kwargs,
     ):
         r"""

@@ -190,7 +189,6 @@ def load_lora_adapter(
                 limitations to this technique, which are documented here:
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap

-            load_with_metadata: TODO
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

@@ -233,7 +231,6 @@ def load_lora_adapter(
             subfolder=subfolder,
             user_agent=user_agent,
             allow_pickle=allow_pickle,
-            load_with_metadata=load_with_metadata,
         )
         if network_alphas is not None and prefix is None:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")

@@ -280,7 +277,6 @@ def load_lora_adapter(
             rank,
             network_alpha_dict=network_alphas,
             peft_state_dict=state_dict,
-            load_with_metadata=load_with_metadata,
             prefix=prefix,
         )
         _maybe_raise_error_for_ambiguity(lora_config_kwargs)
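Note: the model-level entry point loses the flag as well. A hedged sketch of how `load_lora_adapter` is typically called after this commit (the repo ids are illustrative only):

```python
from diffusers import WanTransformer3DModel

transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer"
)

# Any adapter metadata embedded in the checkpoint is picked up automatically;
# there is no `load_with_metadata` switch anymore.
transformer.load_lora_adapter("some-user/some-wan-lora", prefix="transformer")
```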

src/diffusers/utils/peft_utils.py

Lines changed: 9 additions & 5 deletions
@@ -148,12 +148,16 @@ def unscale_lora_layers(model, weight: Optional[float] = None):


 def get_peft_kwargs(
-    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False
+    rank_dict,
+    network_alpha_dict,
+    peft_state_dict,
+    is_unet=True,
+    prefix=None,
 ):
-    if load_with_metadata:
-        if "lora_adapter_metadata" not in peft_state_dict:
-            raise ValueError("Couldn't find 'lora_adapter_metadata' key in the `peft_state_dict`.")
-        metadata = peft_state_dict["lora_adapter_metadata"]
+    from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
+    if LORA_ADAPTER_METADATA_KEY in peft_state_dict:
+        metadata = peft_state_dict[LORA_ADAPTER_METADATA_KEY]
         if prefix is not None:
             metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
         return metadata
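Note: the behavioral change in `get_peft_kwargs` is that the metadata path now keys off the presence of `LORA_ADAPTER_METADATA_KEY` instead of a caller-supplied flag, and module prefixes are stripped before the stored config is returned. A toy illustration of just that branch (the dict below stands in for a real state dict; only a couple of representative config fields are shown):

```python
LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # mirrors the new constant in lora_base.py

peft_state_dict = {
    LORA_ADAPTER_METADATA_KEY: {"transformer.r": 4, "transformer.lora_alpha": 4},
    # ... the actual LoRA weight tensors would sit alongside this entry ...
}
prefix = "transformer"

if LORA_ADAPTER_METADATA_KEY in peft_state_dict:
    metadata = peft_state_dict[LORA_ADAPTER_METADATA_KEY]
    if prefix is not None:
        # Strip the module prefix so the keys match the adapter config fields.
        metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
    print(metadata)  # {'r': 4, 'lora_alpha': 4}
```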

src/diffusers/utils/state_dict_utils.py

Lines changed: 7 additions & 3 deletions
@@ -350,9 +350,15 @@ def state_dict_all_zero(state_dict, filter_str=None):
     return all(torch.all(param == 0).item() for param in state_dict.values())


-def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_key):
+def _maybe_populate_state_dict_with_metadata(state_dict, model_file):
+    if not model_file.endswith(".safetensors"):
+        return state_dict
+
     import safetensors.torch

+    from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
+    metadata_key = LORA_ADAPTER_METADATA_KEY
     with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
         if hasattr(f, "metadata"):
             metadata = f.metadata()

@@ -361,6 +367,4 @@ def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_ke
             if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"):
                 peft_metadata = {k: v for k, v in metadata.items() if k != "format"}
                 state_dict["lora_adapter_metadata"] = json.loads(peft_metadata[metadata_key])
-            else:
-                raise ValueError("Metadata couldn't be parsed from the safetensors file.")
     return state_dict
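Note: the helper's logic boils down to reading the safetensors header metadata and JSON-decoding the adapter entry if one is present. A self-contained sketch under that assumption (the file path is illustrative):

```python
import json

import safetensors.torch

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
model_file = "pytorch_lora_weights.safetensors"  # illustrative path

state_dict = safetensors.torch.load_file(model_file, device="cpu")
with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

# Besides the "format" entry, the adapter config (if any) lives under its own key.
if LORA_ADAPTER_METADATA_KEY in metadata:
    state_dict[LORA_ADAPTER_METADATA_KEY] = json.loads(metadata[LORA_ADAPTER_METADATA_KEY])
```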

tests/lora/test_lora_layers_wan.py

Lines changed: 3 additions & 2 deletions
@@ -26,6 +26,7 @@
     WanPipeline,
     WanTransformer3DModel,
 )
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
 from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device


@@ -162,9 +163,9 @@ def test_adapter_metadata_is_loaded_correctly(self):
             pipe.unload_lora_weights()
             state_dict = pipe.lora_state_dict(tmpdir, load_with_metadata=True)

-            self.assertTrue("lora_metadata" in state_dict)
+            self.assertTrue(LORA_ADAPTER_METADATA_KEY in state_dict)

-            parsed_metadata = state_dict["lora_metadata"]
+            parsed_metadata = state_dict[LORA_ADAPTER_METADATA_KEY]
             parsed_metadata = {k[len("transformer.") :]: v for k, v in parsed_metadata.items()}
             check_if_dicts_are_equal(parsed_metadata, metadata)
