@@ -737,7 +737,6 @@ def lora_state_dict(
737737 """
738738 # Load the main state dict first which has the LoRA layers for either of
739739 # UNet and text encoder or both.
740-
741740 cache_dir = kwargs .pop ("cache_dir" , None )
742741 force_download = kwargs .pop ("force_download" , False )
743742 proxies = kwargs .pop ("proxies" , None )
@@ -1355,6 +1354,9 @@ def save_lora_weights(
13551354 weight_name : str = None ,
13561355 save_function : Callable = None ,
13571356 safe_serialization : bool = True ,
1357+ transformer_lora_adapter_metadata = None ,
1358+ text_encoder_lora_adapter_metadata = None ,
1359+ text_encoder_2_lora_adapter_metadata = None ,
13581360 ):
13591361 r"""
13601362 Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1380,8 +1382,12 @@ def save_lora_weights(
13801382 `DIFFUSERS_SAVE_MODE`.
13811383 safe_serialization (`bool`, *optional*, defaults to `True`):
13821384 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
1385+ transformer_lora_adapter_metadata: TODO
1386+ text_encoder_lora_adapter_metadata: TODO
1387+ text_encoder_2_lora_adapter_metadata: TODO
13831388 """
13841389 state_dict = {}
1390+ lora_adapter_metadata = {}
13851391
13861392 if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers ):
13871393 raise ValueError (
@@ -1397,13 +1403,23 @@ def save_lora_weights(
13971403 if text_encoder_2_lora_layers :
13981404 state_dict .update (cls .pack_weights (text_encoder_2_lora_layers , "text_encoder_2" ))
13991405
1406+ if transformer_lora_adapter_metadata is not None :
1407+ lora_adapter_metadata .update (cls .pack_weights (transformer_lora_adapter_metadata , cls .transformer_name ))
1408+
1409+ if text_encoder_lora_adapter_metadata :
1410+ lora_adapter_metadata .update (cls .pack_weights (text_encoder_lora_adapter_metadata , cls .text_encoder_name ))
1411+
1412+ if text_encoder_2_lora_adapter_metadata :
1413+ lora_adapter_metadata .update (cls .pack_weights (text_encoder_2_lora_adapter_metadata , "text_encoder_2" ))
1414+
14001415 cls .write_lora_layers (
14011416 state_dict = state_dict ,
14021417 save_directory = save_directory ,
14031418 is_main_process = is_main_process ,
14041419 weight_name = weight_name ,
14051420 save_function = save_function ,
14061421 safe_serialization = safe_serialization ,
1422+ lora_adapter_metadata = lora_adapter_metadata ,
14071423 )
14081424
14091425 # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
0 commit comments