@@ -1669,7 +1669,7 @@ def convert_outputs(self, outputs):
 
 
     # Store model
-    def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
+    def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
              use_safetensors=False, path_safetensors=None, pickle_metadata=False,
              batch_size=None, opset_version=13, pickle_module=dill, ignore_attributes=None, **kwargs):
         """ Save underlying PyTorch model along with meta parameters (config, device spec, etc).
@@ -1681,7 +1681,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
 
         Parameters
         ----------
-        path : str
+        path : Optional[str]
             Path to a file where the model data will be stored.
         use_onnx : bool
             Whether to store model in ONNX format.
@@ -1709,9 +1709,14 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
         kwargs : dict
             Other keyword arguments, passed directly to :func:`torch.save`.
         """
-        dirname = os.path.dirname(path)
-        if dirname and not os.path.exists(dirname):
-            os.makedirs(dirname)
+
+        if path is None:
+            dirname = os.getcwd()
+            path = os.path.join(dirname, "model.pt")
+        else:
+            dirname = os.path.dirname(path)
+            if dirname and not os.path.exists(dirname):
+                os.makedirs(dirname)
 
         # Unwrap DDP model
         if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
@@ -1728,7 +1733,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
             raise ValueError('Specify valid `batch_size`, used for model inference!')
 
         inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False)
-        path_onnx = path_onnx or (path + '.onnx')
+        path_onnx = path_onnx or os.path.join(dirname, "model.onnx")
         torch.onnx.export(self.model.eval(), inputs, path_onnx, opset_version=opset_version)
 
         if pickle_metadata:
@@ -1742,7 +1747,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
         elif use_openvino:
             import openvino as ov
 
-            path_openvino = path_openvino or (path + '.openvino')
+            path_openvino = path_openvino or os.path.join(dirname, "model.openvino")
             if os.path.splitext(path_openvino)[-1] == '':
                 path_openvino = f'{path_openvino}.xml'
@@ -1766,7 +1771,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
             from safetensors.torch import save_file
             state_dict = self.model.state_dict()
 
-            path_safetensors = path_safetensors or (path + '.safetensors')
+            path_safetensors = path_safetensors or os.path.join(dirname, "model.safetensors")
             save_file(state_dict, path_safetensors)
 
             preserved = self.PRESERVE_SAFETENSORS - ignore_attributes
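A minimal usage sketch of the changed defaults; the model class and config below are assumed for illustration and are not part of this diff:

    model = TorchModel(config)                # assumed model class and config
    model.save()                              # path=None -> writes ./model.pt in the current directory
    model.save('runs/exp1/model.pt')          # creates runs/exp1/ if it does not exist
    model.save(use_onnx=True, batch_size=4)   # with path=None, exports ./model.onnx via torch.onnx.export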