@@ -644,6 +644,9 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
+        from .lora_base import LORA_ADAPTER_METADATA_KEY
+
+        print(f"{LORA_ADAPTER_METADATA_KEY in state_dict = } before UNet")
         self.load_lora_into_unet(
             state_dict,
             network_alphas=network_alphas,
@@ -653,6 +656,7 @@ def load_lora_weights(
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
         )
+        print(f"{LORA_ADAPTER_METADATA_KEY in state_dict = } before text encoder.")
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,
@@ -664,6 +668,7 @@ def load_lora_weights(
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
         )
+        print(f"{LORA_ADAPTER_METADATA_KEY in state_dict = } before text encoder 2.")
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,
@@ -732,6 +737,7 @@ def lora_state_dict(
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -914,6 +920,9 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        unet_lora_adapter_metadata=None,
+        text_encoder_lora_adapter_metadata=None,
+        text_encoder_2_lora_adapter_metadata=None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -939,8 +948,12 @@ def save_lora_weights(
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            unet_lora_adapter_metadata: LoRA adapter metadata for the UNet, serialized alongside its state dict.
+            text_encoder_lora_adapter_metadata: LoRA adapter metadata for the text encoder, serialized alongside its state dict.
+            text_encoder_2_lora_adapter_metadata: LoRA adapter metadata for the second text encoder, serialized alongside its state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}
 
         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
@@ -956,13 +969,23 @@ def save_lora_weights(
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
 
+        if unet_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(cls.pack_weights(unet_lora_adapter_metadata, cls.unet_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
+
+        if text_encoder_2_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2"))
+
         cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )
 
     def fuse_lora(
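For context, here is a rough usage sketch of the extended `save_lora_weights` signature as it stands in this diff, assuming an SDXL-style pipeline that uses this mixin; the stand-in state-dict keys and the metadata payloads (`r`, `lora_alpha`) are illustrative assumptions, not something this commit prescribes:

```python
# Hypothetical call into the extended save_lora_weights() from this diff.
import torch
from diffusers import StableDiffusionXLPipeline  # assumed pipeline using this mixin

# Stand-in LoRA state dicts; in practice these come from a trained adapter
# (key names and shapes below are placeholders).
unet_lora_layers = {"mid_block.attentions.0.to_q.lora_A.weight": torch.zeros(4, 640)}
text_encoder_lora_layers = {"text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 768)}

StableDiffusionXLPipeline.save_lora_weights(
    save_directory="./my-sdxl-lora",
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers,
    # Assumed metadata payloads, e.g. the LoraConfig kwargs used at training time.
    unet_lora_adapter_metadata={"r": 4, "lora_alpha": 4},
    text_encoder_lora_adapter_metadata={"r": 4, "lora_alpha": 4},
    safe_serialization=True,
)
```

Because the metadata dicts are routed through `cls.pack_weights` just like the weight dicts, the serialized metadata entries presumably end up namespaced per component (e.g. `unet.*`, `text_encoder.*`).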