@@ -457,6 +457,8 @@ def save_lora_weights(
457457 weight_name : str = None ,
458458 save_function : Callable = None ,
459459 safe_serialization : bool = True ,
460+ unet_lora_adapter_metadata = None ,
461+ text_encoder_lora_adapter_metadata = None ,
460462 ):
461463 r"""
462464 Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -479,8 +481,11 @@ def save_lora_weights(
479481 `DIFFUSERS_SAVE_MODE`.
480482 safe_serialization (`bool`, *optional*, defaults to `True`):
481483 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
484+ unet_lora_adapter_metadata (`dict`, *optional*): LoRA adapter metadata for the UNet, serialized alongside the state dict.
485+ text_encoder_lora_adapter_metadata (`dict`, *optional*): LoRA adapter metadata for the text encoder, serialized alongside the state dict.
482486 """
483487 state_dict = {}
488+ lora_adapter_metadata = {}
484489
485490 if not (unet_lora_layers or text_encoder_lora_layers ):
486491 raise ValueError ("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`." )
@@ -491,6 +496,12 @@ def save_lora_weights(
491496 if text_encoder_lora_layers :
492497 state_dict .update (cls .pack_weights (text_encoder_lora_layers , cls .text_encoder_name ))
493498
499+ if unet_lora_adapter_metadata is not None :
500+ lora_adapter_metadata .update (cls .pack_weights (unet_lora_adapter_metadata , cls .unet_name ))
501+
502+ if text_encoder_lora_adapter_metadata :
503+ lora_adapter_metadata .update (cls .pack_weights (text_encoder_lora_adapter_metadata , cls .text_encoder_name ))
504+
494505 # Save the model
495506 cls .write_lora_layers (
496507 state_dict = state_dict ,
@@ -499,6 +510,7 @@ def save_lora_weights(
499510 weight_name = weight_name ,
500511 save_function = save_function ,
501512 safe_serialization = safe_serialization ,
513+ lora_adapter_metadata = lora_adapter_metadata ,
502514 )
503515
504516 def fuse_lora (
0 commit comments