@@ -2267,6 +2267,8 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -2289,8 +2291,11 @@ def save_lora_weights(
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+            text_encoder_lora_adapter_metadata: LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}
 
         if not (transformer_lora_layers or text_encoder_lora_layers):
             raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
@@ -2301,6 +2306,12 @@ def save_lora_weights(
         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
 
+        if transformer_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
+
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
@@ -2309,6 +2320,7 @@ def save_lora_weights(
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )
 
     def fuse_lora(
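
For illustration, a minimal sketch of how the two new keyword arguments might be called; the pipeline class, the layer dicts, and the metadata keys are assumptions for this example, not part of the diff:

```python
# Hypothetical usage: `MyPipeline`, the layer dicts, and the metadata keys are
# placeholders; only the two `*_lora_adapter_metadata` kwargs come from this change.
transformer_lora_layers = {...}   # LoRA weights collected from the trained transformer
text_encoder_lora_layers = {...}  # LoRA weights collected from the trained text encoder

MyPipeline.save_lora_weights(
    save_directory="my-lora",
    transformer_lora_layers=transformer_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers,
    # New: adapter metadata serialized alongside the weights.
    transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},  # assumed LoraConfig-style keys
    text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8},
)
```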
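Because the metadata dicts are routed through the same `cls.pack_weights` used for the weights, each metadata key ends up namespaced under the corresponding model name. A standalone sketch of that prefixing behavior, assuming `pack_weights` simply prefixes dict keys (consistent with its use above):

```python
# Sketch of the key-prefixing assumed of `pack_weights`: entries are namespaced
# under the model name, so weights and metadata for a model share one prefix.
def pack_weights(layers: dict, prefix: str) -> dict:
    return {f"{prefix}.{name}": value for name, value in layers.items()}

metadata = pack_weights({"r": 16, "lora_alpha": 16}, "transformer")
print(metadata)  # {'transformer.r': 16, 'transformer.lora_alpha': 16}
```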