@@ -3215,28 +3215,24 @@ def save_lora_weights(
32153215 transformer_lora_adapter_metadata:
32163216 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
32173217 """
3218- state_dict = {}
3219- lora_adapter_metadata = {}
3220-
3221- if not transformer_lora_layers :
3222- raise ValueError ("You must pass `transformer_lora_layers`." )
3218+ lora_layers = {}
3219+ lora_metadata = {}
32233220
3224- state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
3221+ if transformer_lora_layers :
3222+ lora_layers [cls .transformer_name ] = transformer_lora_layers
3223+ lora_metadata [cls .transformer_name ] = transformer_lora_adapter_metadata
32253224
3226- if transformer_lora_adapter_metadata is not None :
3227- lora_adapter_metadata .update (
3228- _pack_dict_with_prefix (transformer_lora_adapter_metadata , cls .transformer_name )
3229- )
3225+ if not lora_layers :
3226+ raise ValueError ("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`." )
32303227
3231- # Save the model
3232- cls .write_lora_layers (
3233- state_dict = state_dict ,
3228+ cls ._save_lora_weights (
32343229 save_directory = save_directory ,
3230+ lora_layers = lora_layers ,
3231+ lora_metadata = lora_metadata ,
32353232 is_main_process = is_main_process ,
32363233 weight_name = weight_name ,
32373234 save_function = save_function ,
32383235 safe_serialization = safe_serialization ,
3239- lora_adapter_metadata = lora_adapter_metadata ,
32403236 )
32413237
32423238 def fuse_lora (
0 commit comments