@@ -1064,6 +1064,41 @@ def save_function(weights, filename):
10641064 save_function (state_dict , save_path )
10651065 logger .info (f"Model weights saved in { save_path } " )
10661066
1067+ @classmethod
1068+ def _save_lora_weights (
1069+ cls ,
1070+ save_directory : Union [str , os .PathLike ],
1071+ lora_layers : Dict [str , Dict [str , Union [torch .nn .Module , torch .Tensor ]]],
1072+ lora_metadata : Dict [str , Optional [dict ]],
1073+ is_main_process : bool = True ,
1074+ weight_name : str = None ,
1075+ save_function : Callable = None ,
1076+ safe_serialization : bool = True ,
1077+ ):
1078+ """
1079+ Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
1080+ pipeline types.
1081+ """
1082+ state_dict = {}
1083+ final_lora_adapter_metadata = {}
1084+
1085+ for prefix , layers in lora_layers .items ():
1086+ state_dict .update (cls .pack_weights (layers , prefix ))
1087+
1088+ for prefix , metadata in lora_metadata .items ():
1089+ if metadata :
1090+ final_lora_adapter_metadata .update (_pack_dict_with_prefix (metadata , prefix ))
1091+
1092+ cls .write_lora_layers (
1093+ state_dict = state_dict ,
1094+ save_directory = save_directory ,
1095+ is_main_process = is_main_process ,
1096+ weight_name = weight_name ,
1097+ save_function = save_function ,
1098+ safe_serialization = safe_serialization ,
1099+ lora_adapter_metadata = final_lora_adapter_metadata if final_lora_adapter_metadata else None ,
1100+ )
1101+
10671102 @classmethod
10681103 def _optionally_disable_offloading (cls , _pipeline ):
10691104 return _func_optionally_disable_offloading (_pipeline = _pipeline )
0 commit comments