@@ -1669,8 +1669,7 @@ def convert_outputs(self, outputs):
16691669
16701670
16711671 # Store model
1672- def save (self , path = None , use_onnx = False , path_onnx = None , use_openvino = False , path_openvino = None ,
1673- use_safetensors = False , path_safetensors = None , pickle_metadata = False ,
1672+ def save (self , path , format = "pt" , pickle_metadata = False ,
16741673 batch_size = None , opset_version = 13 , pickle_module = dill , ignore_attributes = None , ** kwargs ):
16751674 """ Save underlying PyTorch model along with meta parameters (config, device spec, etc).
16761675
@@ -1681,7 +1680,7 @@ def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, pa
16811680
16821681 Parameters
16831682 ----------
1684- path : Optional[ str]
1683+ path : str
16851684 Path to a file where the model data will be stored.
16861685 use_onnx: bool
16871686 Whether to store model in ONNX format.
@@ -1709,14 +1708,11 @@ def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, pa
17091708 kwargs : dict
17101709 Other keyword arguments, passed directly to :func:`torch.save`.
17111710 """
1712-
1713- if path is None :
1714- dirname = os .getcwd ()
1715- path = os .path .join (dirname , "model.pt" )
1716- else :
1717- dirname = os .path .dirname (path )
1718- if dirname and not os .path .exists (dirname ):
1719- os .makedirs (dirname )
1711+ available_formats = ("pt" , "onnx" , "openvino" , "safetensors" )
1712+ assert format in available_formats , f"Format must be in { available_formats } "
1713+ dirname = os .path .dirname (path )
1714+ if dirname and not os .path .exists (dirname ):
1715+ os .makedirs (dirname )
17201716
17211717 # Unwrap DDP model
17221718 if isinstance (self .model , torch .nn .parallel .DistributedDataParallel ):
@@ -1728,12 +1724,12 @@ def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, pa
17281724 ignore_attributes = []
17291725 ignore_attributes = set (ignore_attributes )
17301726
1731- if use_onnx :
1727+ if format == "onnx" :
17321728 if batch_size is None :
17331729 raise ValueError ('Specify valid `batch_size`, used for model inference!' )
17341730
17351731 inputs = self .make_placeholder_data (batch_size = batch_size , unwrap = False )
1736- path_onnx = path_onnx or os .path .join (dirname , "model.onnx" )
1732+ path_onnx = path if path . endswith ( ".onnx" ) else os .path .join (dirname , "model.onnx" )
17371733 torch .onnx .export (self .model .eval (), inputs , path_onnx , opset_version = opset_version )
17381734
17391735 if pickle_metadata :
@@ -1744,10 +1740,10 @@ def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, pa
17441740 torch .save ({'onnx' : True , 'path_onnx' : path_onnx , 'onnx_batch_size' : batch_size , ** preserved_dict },
17451741 path , pickle_module = pickle_module , ** kwargs )
17461742
1747- elif use_openvino :
1743+ elif format == "openvino" :
17481744 import openvino as ov
17491745
1750- path_openvino = path_openvino or os .path .join (dirname , "model.openvino" )
1746+ path_openvino = path if path . endswith ( ".openvino" ) else os .path .join (dirname , "model.openvino" )
17511747 if os .path .splitext (path_openvino )[- 1 ] == '' :
17521748 path_openvino = f'{ path_openvino } .xml'
17531749
@@ -1767,11 +1763,11 @@ def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, pa
17671763 torch .save ({'openvino' : True , 'path_openvino' : path_openvino , ** preserved_dict },
17681764 path , pickle_module = pickle_module , ** kwargs )
17691765
1770- elif use_safetensors :
1766+ elif format == "safetensors" :
17711767 from safetensors .torch import save_file
17721768 state_dict = self .model .state_dict ()
17731769
1774- path_safetensors = path_safetensors or os .path .join (dirname , "model.safetensors" )
1770+ path_safetensors = path if path . endswith ( ".safetensors" ) else os .path .join (dirname , "model.safetensors" )
17751771 save_file (state_dict , path_safetensors )
17761772
17771773 preserved = self .PRESERVE_SAFETENSORS - ignore_attributes
0 commit comments