@@ -764,6 +764,7 @@ def lora_state_dict(
764764 The subfolder location of a model file within a larger model repository on the Hub or locally.
765765 weight_name (`str`, *optional*, defaults to None):
766766 Name of the serialized state dict file.
767+ return_lora_metadata: TODO
767768 """
768769 # Load the main state dict first which has the LoRA layers for either of
769770 # UNet and text encoder or both.
@@ -777,6 +778,7 @@ def lora_state_dict(
777778 weight_name = kwargs .pop ("weight_name" , None )
778779 unet_config = kwargs .pop ("unet_config" , None )
779780 use_safetensors = kwargs .pop ("use_safetensors" , None )
781+ return_lora_metadata = kwargs .pop ("return_lora_metadata" , False )
780782
781783 allow_pickle = False
782784 if use_safetensors is None :
@@ -788,7 +790,7 @@ def lora_state_dict(
788790 "framework" : "pytorch" ,
789791 }
790792
791- state_dict = _fetch_state_dict (
793+ state_dict , metadata = _fetch_state_dict (
792794 pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
793795 weight_name = weight_name ,
794796 use_safetensors = use_safetensors ,
@@ -825,7 +827,8 @@ def lora_state_dict(
825827 state_dict = _maybe_map_sgm_blocks_to_diffusers (state_dict , unet_config )
826828 state_dict , network_alphas = _convert_non_diffusers_lora_to_diffusers (state_dict )
827829
828- return state_dict , network_alphas
830+ out = (state_dict , network_alphas , metadata ) if return_lora_metadata else (state_dict , network_alphas )
831+ return out
829832
830833 @classmethod
831834 # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
@@ -835,6 +838,7 @@ def load_lora_into_unet(
835838 network_alphas ,
836839 unet ,
837840 adapter_name = None ,
841+ metadata = None ,
838842 _pipeline = None ,
839843 low_cpu_mem_usage = False ,
840844 hotswap : bool = False ,
@@ -879,6 +883,7 @@ def load_lora_into_unet(
879883 prefix = cls .unet_name ,
880884 network_alphas = network_alphas ,
881885 adapter_name = adapter_name ,
886+ metadata = metadata ,
882887 _pipeline = _pipeline ,
883888 low_cpu_mem_usage = low_cpu_mem_usage ,
884889 hotswap = hotswap ,
@@ -894,6 +899,7 @@ def load_lora_into_text_encoder(
894899 prefix = None ,
895900 lora_scale = 1.0 ,
896901 adapter_name = None ,
902+ metadata = None ,
897903 _pipeline = None ,
898904 low_cpu_mem_usage = False ,
899905 hotswap : bool = False ,
@@ -919,6 +925,7 @@ def load_lora_into_text_encoder(
919925 adapter_name (`str`, *optional*):
920926 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
921927 `default_{i}` where i is the total number of adapters being loaded.
928+ metadata: TODO
922929 low_cpu_mem_usage (`bool`, *optional*):
923930 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
924931 weights.
@@ -933,6 +940,7 @@ def load_lora_into_text_encoder(
933940 prefix = prefix ,
934941 text_encoder_name = cls .text_encoder_name ,
935942 adapter_name = adapter_name ,
943+ metadata = metadata ,
936944 _pipeline = _pipeline ,
937945 low_cpu_mem_usage = low_cpu_mem_usage ,
938946 hotswap = hotswap ,
@@ -1331,6 +1339,7 @@ def load_lora_into_text_encoder(
13311339 prefix = None ,
13321340 lora_scale = 1.0 ,
13331341 adapter_name = None ,
1342+ metadata = None ,
13341343 _pipeline = None ,
13351344 low_cpu_mem_usage = False ,
13361345 hotswap : bool = False ,
@@ -1356,6 +1365,7 @@ def load_lora_into_text_encoder(
13561365 adapter_name (`str`, *optional*):
13571366 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
13581367 `default_{i}` where i is the total number of adapters being loaded.
1368+ metadata: TODO
13591369 low_cpu_mem_usage (`bool`, *optional*):
13601370 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
13611371 weights.
@@ -1370,6 +1380,7 @@ def load_lora_into_text_encoder(
13701380 prefix = prefix ,
13711381 text_encoder_name = cls .text_encoder_name ,
13721382 adapter_name = adapter_name ,
1383+ metadata = metadata ,
13731384 _pipeline = _pipeline ,
13741385 low_cpu_mem_usage = low_cpu_mem_usage ,
13751386 hotswap = hotswap ,
@@ -2237,6 +2248,7 @@ def load_lora_into_text_encoder(
22372248 prefix = None ,
22382249 lora_scale = 1.0 ,
22392250 adapter_name = None ,
2251+ metadata = None ,
22402252 _pipeline = None ,
22412253 low_cpu_mem_usage = False ,
22422254 hotswap : bool = False ,
@@ -2262,6 +2274,7 @@ def load_lora_into_text_encoder(
22622274 adapter_name (`str`, *optional*):
22632275 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
22642276 `default_{i}` where i is the total number of adapters being loaded.
2277+ metadata: TODO
22652278 low_cpu_mem_usage (`bool`, *optional*):
22662279 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
22672280 weights.
@@ -2276,6 +2289,7 @@ def load_lora_into_text_encoder(
22762289 prefix = prefix ,
22772290 text_encoder_name = cls .text_encoder_name ,
22782291 adapter_name = adapter_name ,
2292+ metadata = metadata ,
22792293 _pipeline = _pipeline ,
22802294 low_cpu_mem_usage = low_cpu_mem_usage ,
22812295 hotswap = hotswap ,
@@ -2333,11 +2347,13 @@ def save_lora_weights(
23332347 if text_encoder_lora_layers :
23342348 state_dict .update (cls .pack_weights (text_encoder_lora_layers , cls .text_encoder_name ))
23352349
2336- if transformer_lora_adapter_metadata is not None :
2337- lora_adapter_metadata .update (cls . pack_weights (transformer_lora_adapter_metadata , cls .transformer_name ))
2350+ if transformer_lora_adapter_metadata :
2351+ lora_adapter_metadata .update (_pack_sd_with_prefix (transformer_lora_adapter_metadata , cls .transformer_name ))
23382352
23392353 if text_encoder_lora_adapter_metadata :
2340- lora_adapter_metadata .update (cls .pack_weights (text_encoder_lora_adapter_metadata , cls .text_encoder_name ))
2354+ lora_adapter_metadata .update (
2355+ _pack_sd_with_prefix (text_encoder_lora_adapter_metadata , cls .text_encoder_name )
2356+ )
23412357
23422358 # Save the model
23432359 cls .write_lora_layers (
@@ -2769,6 +2785,7 @@ def load_lora_into_text_encoder(
27692785 prefix = None ,
27702786 lora_scale = 1.0 ,
27712787 adapter_name = None ,
2788+ metadata = None ,
27722789 _pipeline = None ,
27732790 low_cpu_mem_usage = False ,
27742791 hotswap : bool = False ,
@@ -2794,6 +2811,7 @@ def load_lora_into_text_encoder(
27942811 adapter_name (`str`, *optional*):
27952812 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
27962813 `default_{i}` where i is the total number of adapters being loaded.
2814+ metadata: TODO
27972815 low_cpu_mem_usage (`bool`, *optional*):
27982816 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
27992817 weights.
@@ -2808,6 +2826,7 @@ def load_lora_into_text_encoder(
28082826 prefix = prefix ,
28092827 text_encoder_name = cls .text_encoder_name ,
28102828 adapter_name = adapter_name ,
2829+ metadata = metadata ,
28112830 _pipeline = _pipeline ,
28122831 low_cpu_mem_usage = low_cpu_mem_usage ,
28132832 hotswap = hotswap ,
0 commit comments