@@ -1668,7 +1668,8 @@ def convert_outputs(self, outputs):
16681668
16691669
16701670 # Store model
1671- def save (self , path , use_onnx = False , path_onnx = None , use_openvino = False , use_safetensors = False , path_openvino = None ,
1671+ def save (self , path , use_onnx = False , path_onnx = None , use_openvino = False , path_openvino = None ,
1672+ use_safetensors = False , path_safetensors = None ,
16721673 batch_size = None , opset_version = 13 , pickle_module = dill , ignore_attributes = None , ** kwargs ):
16731674 """ Save underlying PyTorch model along with meta parameters (config, device spec, etc).
16741675
@@ -1761,7 +1762,11 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, use_saf
17611762 elif use_safetensors :
17621763 from safetensors .torch import save_file
17631764 state_dict = self .model .state_dict ()
1764- save_file (state_dict , path )
1765+
1766+ path_safetensors = path_safetensors or (path + "safetensors" )
1767+ save_file (state_dict , path_safetensors )
1768+ torch .save ({'safetensors' : True , 'path_safetensors' : path_safetensors , ** preserved_dict },
1769+ path , pickle_module = pickle_module , ** kwargs )
17651770
17661771 else :
17671772 preserved = set (self .PRESERVE ) - set (ignore_attributes )
0 commit comments