@@ -509,35 +509,29 @@ def save_lora_weights(
509509            text_encoder_lora_adapter_metadata: 
510510                LoRA adapter metadata associated with the text encoder to be serialized with the state dict. 
511511        """ 
512-         state_dict  =  {}
513-         lora_adapter_metadata  =  {}
514- 
515-         if  not  (unet_lora_layers  or  text_encoder_lora_layers ):
516-             raise  ValueError ("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`." )
512+         lora_layers  =  {}
513+         lora_metadata  =  {}
517514
518515        if  unet_lora_layers :
519-             state_dict .update (cls .pack_weights (unet_lora_layers , cls .unet_name ))
516+             lora_layers [cls .unet_name ] =  unet_lora_layers 
517+             lora_metadata [cls .unet_name ] =  unet_lora_adapter_metadata 
520518
521519        if  text_encoder_lora_layers :
522-             state_dict .update (cls .pack_weights (text_encoder_lora_layers , cls .text_encoder_name ))
520+             lora_layers [cls .text_encoder_name ] =  text_encoder_lora_layers 
521+             lora_metadata [cls .text_encoder_name ] =  text_encoder_lora_adapter_metadata 
523522
524-         if  unet_lora_adapter_metadata :
525-             lora_adapter_metadata . update ( _pack_dict_with_prefix ( unet_lora_adapter_metadata ,  cls . unet_name ) )
523+         if  not   lora_layers :
524+             raise   ValueError ( "You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`." )
526525
527-         if  text_encoder_lora_adapter_metadata :
528-             lora_adapter_metadata .update (
529-                 _pack_dict_with_prefix (text_encoder_lora_adapter_metadata , cls .text_encoder_name )
530-             )
531- 
532-         # Save the model 
533-         cls .write_lora_layers (
534-             state_dict = state_dict ,
526+         # Delegate to the base helper method 
527+         cls ._save_lora_weights (
535528            save_directory = save_directory ,
529+             lora_layers = lora_layers ,
530+             lora_metadata = lora_metadata ,
536531            is_main_process = is_main_process ,
537532            weight_name = weight_name ,
538533            save_function = save_function ,
539534            safe_serialization = safe_serialization ,
540-             lora_adapter_metadata = lora_adapter_metadata ,
541535        )
542536
543537    def  fuse_lora (
@@ -1003,44 +997,34 @@ def save_lora_weights(
1003997            text_encoder_2_lora_adapter_metadata: 
1004998                LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. 
1005999        """ 
1006-         state_dict  =  {}
1007-         lora_adapter_metadata  =  {}
1008- 
1009-         if  not  (unet_lora_layers  or  text_encoder_lora_layers  or  text_encoder_2_lora_layers ):
1010-             raise  ValueError (
1011-                 "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." 
1012-             )
1000+         lora_layers  =  {}
1001+         lora_metadata  =  {}
10131002
10141003        if  unet_lora_layers :
1015-             state_dict .update (cls .pack_weights (unet_lora_layers , cls .unet_name ))
1004+             lora_layers [cls .unet_name ] =  unet_lora_layers 
1005+             lora_metadata [cls .unet_name ] =  unet_lora_adapter_metadata 
10161006
10171007        if  text_encoder_lora_layers :
1018-             state_dict .update (cls .pack_weights (text_encoder_lora_layers , "text_encoder" ))
1008+             lora_layers ["text_encoder" ] =  text_encoder_lora_layers 
1009+             lora_metadata ["text_encoder" ] =  text_encoder_lora_adapter_metadata 
10191010
10201011        if  text_encoder_2_lora_layers :
1021-             state_dict .update (cls .pack_weights (text_encoder_2_lora_layers , "text_encoder_2" ))
1012+             lora_layers ["text_encoder_2" ] =  text_encoder_2_lora_layers 
1013+             lora_metadata ["text_encoder_2" ] =  text_encoder_2_lora_adapter_metadata 
10221014
1023-         if  unet_lora_adapter_metadata  is  not None :
1024-             lora_adapter_metadata .update (_pack_dict_with_prefix (unet_lora_adapter_metadata , cls .unet_name ))
1025- 
1026-         if  text_encoder_lora_adapter_metadata :
1027-             lora_adapter_metadata .update (
1028-                 _pack_dict_with_prefix (text_encoder_lora_adapter_metadata , cls .text_encoder_name )
1029-             )
1030- 
1031-         if  text_encoder_2_lora_adapter_metadata :
1032-             lora_adapter_metadata .update (
1033-                 _pack_dict_with_prefix (text_encoder_2_lora_adapter_metadata , "text_encoder_2" )
1015+         if  not  lora_layers :
1016+             raise  ValueError (
1017+                 "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`." 
10341018            )
10351019
1036-         cls .write_lora_layers (
1037-             state_dict = state_dict ,
1020+         cls ._save_lora_weights (
10381021            save_directory = save_directory ,
1022+             lora_layers = lora_layers ,
1023+             lora_metadata = lora_metadata ,
10391024            is_main_process = is_main_process ,
10401025            weight_name = weight_name ,
10411026            save_function = save_function ,
10421027            safe_serialization = safe_serialization ,
1043-             lora_adapter_metadata = lora_adapter_metadata ,
10441028        )
10451029
10461030    def  fuse_lora (
@@ -1466,46 +1450,34 @@ def save_lora_weights(
14661450            text_encoder_2_lora_adapter_metadata: 
14671451                LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. 
14681452        """ 
1469-         state_dict  =  {}
1470-         lora_adapter_metadata  =  {}
1471- 
1472-         if  not  (transformer_lora_layers  or  text_encoder_lora_layers  or  text_encoder_2_lora_layers ):
1473-             raise  ValueError (
1474-                 "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." 
1475-             )
1453+         lora_layers  =  {}
1454+         lora_metadata  =  {}
14761455
14771456        if  transformer_lora_layers :
1478-             state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
1457+             lora_layers [cls .transformer_name ] =  transformer_lora_layers 
1458+             lora_metadata [cls .transformer_name ] =  transformer_lora_adapter_metadata 
14791459
14801460        if  text_encoder_lora_layers :
1481-             state_dict .update (cls .pack_weights (text_encoder_lora_layers , "text_encoder" ))
1461+             lora_layers ["text_encoder" ] =  text_encoder_lora_layers 
1462+             lora_metadata ["text_encoder" ] =  text_encoder_lora_adapter_metadata 
14821463
14831464        if  text_encoder_2_lora_layers :
1484-             state_dict .update (cls .pack_weights (text_encoder_2_lora_layers , "text_encoder_2" ))
1465+             lora_layers ["text_encoder_2" ] =  text_encoder_2_lora_layers 
1466+             lora_metadata ["text_encoder_2" ] =  text_encoder_2_lora_adapter_metadata 
14851467
1486-         if  transformer_lora_adapter_metadata  is  not None :
1487-             lora_adapter_metadata .update (
1488-                 _pack_dict_with_prefix (transformer_lora_adapter_metadata , cls .transformer_name )
1489-             )
1490- 
1491-         if  text_encoder_lora_adapter_metadata :
1492-             lora_adapter_metadata .update (
1493-                 _pack_dict_with_prefix (text_encoder_lora_adapter_metadata , cls .text_encoder_name )
1494-             )
1495- 
1496-         if  text_encoder_2_lora_adapter_metadata :
1497-             lora_adapter_metadata .update (
1498-                 _pack_dict_with_prefix (text_encoder_2_lora_adapter_metadata , "text_encoder_2" )
1468+         if  not  lora_layers :
1469+             raise  ValueError (
1470+                 "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`." 
14991471            )
15001472
1501-         cls .write_lora_layers (
1502-             state_dict = state_dict ,
1473+         cls ._save_lora_weights (
15031474            save_directory = save_directory ,
1475+             lora_layers = lora_layers ,
1476+             lora_metadata = lora_metadata ,
15041477            is_main_process = is_main_process ,
15051478            weight_name = weight_name ,
15061479            save_function = save_function ,
15071480            safe_serialization = safe_serialization ,
1508-             lora_adapter_metadata = lora_adapter_metadata ,
15091481        )
15101482
15111483    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer 
@@ -2434,37 +2406,29 @@ def save_lora_weights(
24342406            text_encoder_lora_adapter_metadata: 
24352407                LoRA adapter metadata associated with the text encoder to be serialized with the state dict. 
24362408        """ 
2437-         state_dict  =  {}
2438-         lora_adapter_metadata  =  {}
2439- 
2440-         if  not  (transformer_lora_layers  or  text_encoder_lora_layers ):
2441-             raise  ValueError ("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`." )
2409+         lora_layers  =  {}
2410+         lora_metadata  =  {}
24422411
24432412        if  transformer_lora_layers :
2444-             state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
2413+             lora_layers [cls .transformer_name ] =  transformer_lora_layers 
2414+             lora_metadata [cls .transformer_name ] =  transformer_lora_adapter_metadata 
24452415
24462416        if  text_encoder_lora_layers :
2447-             state_dict .update (cls .pack_weights (text_encoder_lora_layers , cls .text_encoder_name ))
2448- 
2449-         if  transformer_lora_adapter_metadata :
2450-             lora_adapter_metadata .update (
2451-                 _pack_dict_with_prefix (transformer_lora_adapter_metadata , cls .transformer_name )
2452-             )
2417+             lora_layers [cls .text_encoder_name ] =  text_encoder_lora_layers 
2418+             lora_metadata [cls .text_encoder_name ] =  text_encoder_lora_adapter_metadata 
24532419
2454-         if  text_encoder_lora_adapter_metadata :
2455-             lora_adapter_metadata .update (
2456-                 _pack_dict_with_prefix (text_encoder_lora_adapter_metadata , cls .text_encoder_name )
2457-             )
2420+         if  not  lora_layers :
2421+             raise  ValueError ("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`." )
24582422
2459-         # Save the model 
2460-         cls .write_lora_layers (
2461-             state_dict = state_dict ,
2423+         # Delegate to the base helper method 
2424+         cls ._save_lora_weights (
24622425            save_directory = save_directory ,
2426+             lora_layers = lora_layers ,
2427+             lora_metadata = lora_metadata ,
24632428            is_main_process = is_main_process ,
24642429            weight_name = weight_name ,
24652430            save_function = save_function ,
24662431            safe_serialization = safe_serialization ,
2467-             lora_adapter_metadata = lora_adapter_metadata ,
24682432        )
24692433
24702434    def  fuse_lora (
0 commit comments