@@ -1064,6 +1064,41 @@ def save_function(weights, filename):
10641064        save_function (state_dict , save_path )
10651065        logger .info (f"Model weights saved in { save_path }  " )
10661066
1067+     @classmethod  
1068+     def  _save_lora_weights (
1069+         cls ,
1070+         save_directory : Union [str , os .PathLike ],
1071+         lora_layers : Dict [str , Dict [str , Union [torch .nn .Module , torch .Tensor ]]],
1072+         lora_metadata : Dict [str , Optional [dict ]],
1073+         is_main_process : bool  =  True ,
1074+         weight_name : str  =  None ,
1075+         save_function : Callable  =  None ,
1076+         safe_serialization : bool  =  True ,
1077+     ):
1078+         """ 
1079+         Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all 
1080+         pipeline types. 
1081+         """ 
1082+         state_dict  =  {}
1083+         final_lora_adapter_metadata  =  {}
1084+ 
1085+         for  prefix , layers  in  lora_layers .items ():
1086+             state_dict .update (cls .pack_weights (layers , prefix ))
1087+ 
1088+         for  prefix , metadata  in  lora_metadata .items ():
1089+             if  metadata :
1090+                 final_lora_adapter_metadata .update (_pack_dict_with_prefix (metadata , prefix ))
1091+ 
1092+         cls .write_lora_layers (
1093+             state_dict = state_dict ,
1094+             save_directory = save_directory ,
1095+             is_main_process = is_main_process ,
1096+             weight_name = weight_name ,
1097+             save_function = save_function ,
1098+             safe_serialization = safe_serialization ,
1099+             lora_adapter_metadata = final_lora_adapter_metadata  if  final_lora_adapter_metadata  else  None ,
1100+         )
1101+ 
10671102    @classmethod  
10681103    def  _optionally_disable_offloading (cls , _pipeline ):
10691104        return  _func_optionally_disable_offloading (_pipeline = _pipeline )
0 commit comments