Skip to content

Commit 9fed358

Browse files
author
igor
committed
Update path logic. Make path optional
1 parent 6b1feba commit 9fed358

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

batchflow/models/torch/base.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,7 +1669,7 @@ def convert_outputs(self, outputs):
16691669

16701670

16711671
# Store model
1672-
def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
1672+
def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
16731673
use_safetensors=False, path_safetensors=None, pickle_metadata=False,
16741674
batch_size=None, opset_version=13, pickle_module=dill, ignore_attributes=None, **kwargs):
16751675
""" Save underlying PyTorch model along with meta parameters (config, device spec, etc).
@@ -1681,7 +1681,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
16811681
16821682
Parameters
16831683
----------
1684-
path : str
1684+
path : Optional[str]
16851685
Path to a file where the model data will be stored.
16861686
use_onnx: bool
16871687
Whether to store model in ONNX format.
@@ -1709,9 +1709,14 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
17091709
kwargs : dict
17101710
Other keyword arguments, passed directly to :func:`torch.save`.
17111711
"""
1712-
dirname = os.path.dirname(path)
1713-
if dirname and not os.path.exists(dirname):
1714-
os.makedirs(dirname)
1712+
1713+
if path is None:
1714+
dirname = os.getcwd()
1715+
path = os.path.join(dirname, "model.pt")
1716+
else:
1717+
dirname = os.path.dirname(path)
1718+
if dirname and not os.path.exists(dirname):
1719+
os.makedirs(dirname)
17151720

17161721
# Unwrap DDP model
17171722
if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
@@ -1728,7 +1733,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
17281733
raise ValueError('Specify valid `batch_size`, used for model inference!')
17291734

17301735
inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False)
1731-
path_onnx = path_onnx or (path + '.onnx')
1736+
path_onnx = path_onnx or os.path.join(dirname, "model.onnx")
17321737
torch.onnx.export(self.model.eval(), inputs, path_onnx, opset_version=opset_version)
17331738

17341739
if pickle_metadata:
@@ -1742,7 +1747,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
17421747
elif use_openvino:
17431748
import openvino as ov
17441749

1745-
path_openvino = path_openvino or (path + '.openvino')
1750+
path_openvino = path_openvino or os.path.join(dirname, "model.openvino")
17461751
if os.path.splitext(path_openvino)[-1] == '':
17471752
path_openvino = f'{path_openvino}.xml'
17481753

@@ -1766,7 +1771,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
17661771
from safetensors.torch import save_file
17671772
state_dict = self.model.state_dict()
17681773

1769-
path_safetensors = path_safetensors or (path + '.safetensors')
1774+
path_safetensors = path_safetensors or os.path.join(dirname, "model.safetensors")
17701775
save_file(state_dict, path_safetensors)
17711776

17721777
preserved = self.PRESERVE_SAFETENSORS - ignore_attributes

0 commit comments

Comments
 (0)