Commit 425ea95

fix-copies.
1 parent 4304a6d commit 425ea95

1 file changed: +24 −5 lines

src/diffusers/loaders/lora_pipeline.py

Lines changed: 24 additions & 5 deletions
@@ -764,6 +764,7 @@ def lora_state_dict(
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            return_lora_metadata: TODO
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -777,6 +778,7 @@ def lora_state_dict(
         weight_name = kwargs.pop("weight_name", None)
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -788,7 +790,7 @@ def lora_state_dict(
             "framework": "pytorch",
         }
 
-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -825,7 +827,8 @@ def lora_state_dict(
             state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
 
-        return state_dict, network_alphas
+        out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
+        return out
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
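
With this change, the returned tuple gains an optional third element. A minimal usage sketch, assuming an SDXL pipeline class and a placeholder LoRA repo id (both illustrative, not part of this diff):

# Sketch only: "some-user/some-sdxl-lora" is a hypothetical repo id and the
# pipeline class is an assumption; the return shapes follow the hunk above.
from diffusers import StableDiffusionXLPipeline

# Default behaviour is unchanged: a (state_dict, network_alphas) pair.
state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(
    "some-user/some-sdxl-lora"
)

# Opting in via the new kwarg also returns the parsed LoRA metadata.
state_dict, network_alphas, metadata = StableDiffusionXLPipeline.lora_state_dict(
    "some-user/some-sdxl-lora", return_lora_metadata=True
)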
@@ -835,6 +838,7 @@ def load_lora_into_unet(
         network_alphas,
         unet,
         adapter_name=None,
+        metadata=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
@@ -879,6 +883,7 @@ def load_lora_into_unet(
             prefix=cls.unet_name,
             network_alphas=network_alphas,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -894,6 +899,7 @@ def load_lora_into_text_encoder(
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        metadata=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
@@ -919,6 +925,7 @@ def load_lora_into_text_encoder(
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            metadata: TODO
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
@@ -933,6 +940,7 @@ def load_lora_into_text_encoder(
             prefix=prefix,
             text_encoder_name=cls.text_encoder_name,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -1331,6 +1339,7 @@ def load_lora_into_text_encoder(
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        metadata=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
@@ -1356,6 +1365,7 @@ def load_lora_into_text_encoder(
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            metadata: TODO
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
@@ -1370,6 +1380,7 @@ def load_lora_into_text_encoder(
             prefix=prefix,
             text_encoder_name=cls.text_encoder_name,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -2237,6 +2248,7 @@ def load_lora_into_text_encoder(
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        metadata=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
@@ -2262,6 +2274,7 @@ def load_lora_into_text_encoder(
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            metadata: TODO
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
@@ -2276,6 +2289,7 @@ def load_lora_into_text_encoder(
             prefix=prefix,
             text_encoder_name=cls.text_encoder_name,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -2333,11 +2347,13 @@ def save_lora_weights(
         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
 
-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
+        if transformer_lora_adapter_metadata:
+            lora_adapter_metadata.update(_pack_sd_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name))
 
         if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
+            lora_adapter_metadata.update(
+                _pack_sd_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
+            )
 
         # Save the model
         cls.write_lora_layers(
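
Here `_pack_sd_with_prefix` replaces `cls.pack_weights` for the adapter-metadata dicts. Its implementation is not part of this diff; as a rough sketch of the expected behaviour (an assumption, not confirmed by this commit), it namespaces every key of a flat dict under the given component prefix:

# Hypothetical sketch of the prefixing behaviour assumed above; the real
# `_pack_sd_with_prefix` lives elsewhere in diffusers and may differ.
def pack_with_prefix_sketch(sd: dict, prefix: str) -> dict:
    # e.g. {"rank": 4} packed under "transformer" -> {"transformer.rank": 4}
    return {f"{prefix}.{k}": v for k, v in sd.items()}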
@@ -2769,6 +2785,7 @@ def load_lora_into_text_encoder(
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        metadata=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
@@ -2794,6 +2811,7 @@ def load_lora_into_text_encoder(
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            metadata: TODO
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
@@ -2808,6 +2826,7 @@ def load_lora_into_text_encoder(
             prefix=prefix,
             text_encoder_name=cls.text_encoder_name,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
