Skip to content

Commit 7f46222

Browse files
author
igor
committed
Update load
1 parent f6426bd commit 7f46222

File tree

1 file changed

+34
-20
lines changed

1 file changed

+34
-20
lines changed

batchflow/models/torch/base.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,34 +1799,43 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
17991799
else:
18001800
self._parse_devices()
18011801

1802-
if isinstance(file, str) and file.endswith(".safetensors"):
1803-
from safetensors.torch import load_file
1804-
state_dict = load_file(file, device=device)
1805-
self.model.load_state_dict(state_dict)
1802+
if isinstance(file, str):
1803+
if file.endswith(".safetensors"):
1804+
from safetensors.torch import load_file
1805+
state_dict = load_file(file, device=device)
1806+
self.model.load_state_dict(state_dict)
18061807

1807-
self.model_to_device()
1808+
self.model_to_device()
18081809

1809-
if make_infrastructure:
1810-
self.make_infrastructure()
1810+
if make_infrastructure:
1811+
self.make_infrastructure()
18111812

1812-
self.set_model_mode(mode)
1813+
self.set_model_mode(mode)
18131814

1814-
return
1815-
elif isinstance(file, str) and file.endswith(".onnx"):
1816-
try:
1817-
from onnx2torch import convert
1818-
except ImportError as e:
1819-
raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e
1815+
return
1816+
elif file.endswith(".onnx"):
1817+
try:
1818+
from onnx2torch import convert
1819+
except ImportError as e:
1820+
raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e
18201821

1821-
model = convert(file).eval()
1822-
self.model = model
1822+
model = convert(file).eval()
1823+
self.model = model
18231824

1824-
self.model_to_device()
1825+
self.model_to_device()
18251826

1826-
if make_infrastructure:
1827-
self.make_infrastructure()
1827+
if make_infrastructure:
1828+
self.make_infrastructure()
18281829

1829-
self.set_model_mode(mode)
1830+
self.set_model_mode(mode)
1831+
elif file.endswith(".openvino"):
1832+
model = OVModel(model_path=file, **model_load_kwargs)
1833+
self.model = model
1834+
1835+
self._loaded_from_openvino = True
1836+
self.disable_training = True
1837+
1838+
return
18301839

18311840
kwargs['map_location'] = self.device
18321841

@@ -1863,6 +1872,11 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
18631872
self._loaded_from_onnx = True
18641873
self.disable_training = True
18651874

1875+
if "safetensors" in checkpoint:
1876+
from safetensors.torch import load_file
1877+
state_dict = load_file(checkpoint['path_safetensors'], device=device)
1878+
self.model.load_state_dict(state_dict)
1879+
18661880
self.model_to_device()
18671881

18681882
if make_infrastructure:

0 commit comments

Comments
 (0)