
Commit c6c0ec3
Author: igor

Remove all use_xxx and add format=name

1 parent 3169a9b

File tree: 1 file changed, +13 −17 lines


batchflow/models/torch/base.py

Lines changed: 13 additions & 17 deletions
@@ -1669,8 +1669,7 @@ def convert_outputs(self, outputs):
 
 
     # Store model
-    def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
-             use_safetensors=False, path_safetensors=None, pickle_metadata=False,
+    def save(self, path, format="pt", pickle_metadata=False,
              batch_size=None, opset_version=13, pickle_module=dill, ignore_attributes=None, **kwargs):
         """ Save underlying PyTorch model along with meta parameters (config, device spec, etc).
 
@@ -1681,7 +1680,7 @@ def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
 
         Parameters
         ----------
-        path : Optional[str]
+        path : str
             Path to a file where the model data will be stored.
         use_onnx: bool
             Whether to store model in ONNX format.
@@ -1709,14 +1708,11 @@ def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
         kwargs : dict
             Other keyword arguments, passed directly to :func:`torch.save`.
         """
-
-        if path is None:
-            dirname = os.getcwd()
-            path = os.path.join(dirname, "model.pt")
-        else:
-            dirname = os.path.dirname(path)
-            if dirname and not os.path.exists(dirname):
-                os.makedirs(dirname)
+        available_formats = ("pt", "onnx", "openvino", "safetensors")
+        assert format in available_formats, f"Format must be in {available_formats}"
+        dirname = os.path.dirname(path)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname)
 
         # Unwrap DDP model
         if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
@@ -1728,12 +1724,12 @@ def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
             ignore_attributes = []
         ignore_attributes = set(ignore_attributes)
 
-        if use_onnx:
+        if format == "onnx":
             if batch_size is None:
                 raise ValueError('Specify valid `batch_size`, used for model inference!')
 
             inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False)
-            path_onnx = path_onnx or os.path.join(dirname, "model.onnx")
+            path_onnx = path if path.endswith(".onnx") else os.path.join(dirname, "model.onnx")
             torch.onnx.export(self.model.eval(), inputs, path_onnx, opset_version=opset_version)
 
             if pickle_metadata:
17441740
torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict},
17451741
path, pickle_module=pickle_module, **kwargs)
17461742

1747-
elif use_openvino:
1743+
elif format == "openvino":
17481744
import openvino as ov
17491745

1750-
path_openvino = path_openvino or os.path.join(dirname, "model.openvino")
1746+
path_openvino = path if path.endswith(".openvino") else os.path.join(dirname, "model.openvino")
17511747
if os.path.splitext(path_openvino)[-1] == '':
17521748
path_openvino = f'{path_openvino}.xml'
17531749

@@ -1767,11 +1763,11 @@ def save(self, path=None, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
             torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict},
                        path, pickle_module=pickle_module, **kwargs)
 
-        elif use_safetensors:
+        elif format == "safetensors":
             from safetensors.torch import save_file
             state_dict = self.model.state_dict()
 
-            path_safetensors = path_safetensors or os.path.join(dirname, "model.safetensors")
+            path_safetensors = path if path.endswith(".safetensors") else os.path.join(dirname, "model.safetensors")
             save_file(state_dict, path_safetensors)
 
             preserved = self.PRESERVE_SAFETENSORS - ignore_attributes
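
After this commit the separate use_onnx / use_openvino / use_safetensors flags and their path_* companions are gone: callers pick an output format through the single format argument, and the file location comes from path alone. A minimal usage sketch of the new signature, assuming a trained TorchModel instance named model (the paths and batch size below are illustrative, not taken from the repository's tests or docs):

    model.save('dumps/model.pt')                                 # default format="pt": regular torch.save checkpoint
    model.save('dumps/model.pt', format='onnx', batch_size=8)    # ONNX export; batch_size is used to build placeholder inputs
    model.save('dumps/model.pt', format='safetensors')           # weights-only dump via safetensors.torch.save_file

format='openvino' is the fourth accepted value; whether it also requires batch_size is not visible in the hunks above, so it is left out of the sketch. For the non-"pt" formats the extension of path decides where the converted file goes: a path ending in .onnx, .openvino or .safetensors is used directly, otherwise a default file name (model.onnx, model.openvino or model.safetensors) is created next to path in the same directory.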
