Merged

Commits (33):
8a2fc15  Add save in safetensors format (Jun 9, 2025)
bdf275a  Add safetensors to reqs (Jun 9, 2025)
b000c74  Add loading safetensors (Jun 9, 2025)
a86d4e4  Add logic for onnx weights (Jun 9, 2025)
1192a21  Update load (Jun 10, 2025)
d0de633  Update save method (Jun 10, 2025)
e3f15db  Add pickle_metadata flag (Jun 10, 2025)
ba767b0  Add return for onnx load (Jun 10, 2025)
f7cc818  Add preserved_dict for safetensors (Jun 10, 2025)
d5bba46  Update batchflow/models/torch/base.py (igor-iusupov, Jun 10, 2025)
e515c65  Update saving extensions (Jun 10, 2025)
b491290  Update load model for safetensors (Jun 10, 2025)
3d9392c  Small update (Jun 10, 2025)
5304560  Update path logic. Make path optional (Jun 10, 2025)
7305ea9  Fix for linter (Jun 10, 2025)
7b5217a  Remove all use_xxx and add format=name (Jun 11, 2025)
9476e3c  AssertError -> ValueError (Jun 11, 2025)
88f7378  Format -> fmt (Jun 11, 2025)
23d4d94  Add fmt to load (Jun 11, 2025)
fcbb4b7  Update saving model name (Jun 11, 2025)
5b0f63d  Small update (Jun 11, 2025)
53d6cc9  Update batchflow/models/torch/base.py (igor-iusupov, Jun 11, 2025)
19b135d  Update docstring (Jun 11, 2025)
8ed42de  Remove device from load_file func (Jun 11, 2025)
ed57f96  Make fmt optional (Jun 16, 2025)
169f58a  Update model init (Jun 16, 2025)
429d462  Fix model_to_device (AlexeyKozhevin, Jun 27, 2025)
f632dea  Mark test_redirect_stdout as slow (AlexeyKozhevin, Jun 27, 2025)
2a8fce9  Refactoring (AlexeyKozhevin, Jun 27, 2025)
c298bd1  Fix names collisions (AlexeyKozhevin, Jun 27, 2025)
eb306ba  Update tests and load/save (AlexeyKozhevin, Jun 30, 2025)
be86d99  Change logic (AlexeyKozhevin, Jun 30, 2025)
739af96  Minor changes (AlexeyKozhevin, Jun 30, 2025)
136 changes: 106 additions & 30 deletions batchflow/models/torch/base.py
@@ -15,7 +15,6 @@
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR


from sklearn.decomposition import PCA

from ...utils_import import make_delayed_import
@@ -400,6 +399,7 @@ def callable_init(module): # example of a callable for init

PRESERVE_ONNX = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
PRESERVE_OPENVINO = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
PRESERVE_SAFETENSORS = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])

def __init__(self, config=None):
if not isinstance(config, (dict, Config)):
@@ -987,12 +987,16 @@ def transfer_from_device(self, data, force_float32_dtype=True):

def model_to_device(self, model=None):
""" Put model on device(s). If needed, apply DataParallel wrapper. """
model = model if model is not None else self.model
model_ = model if model is not None else self.model

if len(self.devices) > 1:
model = nn.DataParallel(model, self.devices)
model_ = nn.DataParallel(model_, self.devices)
else:
model = model.to(self.device)
model_ = model_.to(self.device)

if model is None:
self.model = model_
return model_
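The rewrite above fixes a subtle bug: the old code rebound the local name `model`, so a no-argument call never stored the DataParallel-wrapped or device-moved module back on the instance. A minimal sketch of the two calling modes, assuming a built model instance (names are illustrative):

model.model_to_device()                    # no argument: self.model is wrapped/moved and stored back
prepared = model.model_to_device(module)   # explicit argument: the prepared module is only returned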


# Apply model to train/predict on given data
@@ -1669,7 +1673,7 @@ def convert_outputs(self, outputs):


# Store model
def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
def save(self, path, fmt=None, pickle_metadata=True,
batch_size=None, opset_version=13, pickle_module=dill, ignore_attributes=None, **kwargs):
""" Save underlying PyTorch model along with meta parameters (config, device spec, etc).

@@ -1682,17 +1686,10 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
----------
path : str
Path to a file where the model data will be stored.
use_onnx: bool
Whether to store model in ONNX format.
path_onnx : str, optional
Used only if `use_onnx` is True.
If provided, then path to store the ONNX model; default `path_onnx` is `path` with '_onnx' postfix.
use_openvino: bool
Whether to store model as openvino xml file.
path_openvino : str, optional
Used only if `use_openvino` is True.
If provided, then path to store the openvino model; default `path_openvino` is `path` with '_openvino'
postfix.
fmt : str, optional
Weights format. Available formats: "pt", "onnx", "openvino", "safetensors".
pickle_metadata : bool
Whether to additionally pickle model metadata (config, device spec, etc.) to `path`.
batch_size : int, optional
Used only if `fmt` is "onnx".
Fixed batch size of the ONNX module. This is the only viable batch size for this model after loading.
@@ -1706,6 +1703,14 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
kwargs : dict
Other keyword arguments, passed directly to :func:`torch.save`.
"""
available_formats = ("pt", "onnx", "openvino", "safetensors")

if fmt is None:
fmt = os.path.splitext(path)[-1][1:]

if fmt not in available_formats:
raise ValueError(f"fmt must be in {available_formats}")

dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)
@@ -1720,25 +1725,27 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
ignore_attributes = []
ignore_attributes = set(ignore_attributes)

if use_onnx:
if fmt == "onnx":
if batch_size is None:
raise ValueError('Specify valid `batch_size`, used for model inference!')

inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False)
path_onnx = path_onnx or (path + '_onnx')

path_onnx = path if not pickle_metadata else os.path.splitext(path)[0] + ".onnx"
torch.onnx.export(self.model.eval(), inputs, path_onnx, opset_version=opset_version)

# Save the rest of parameters
preserved = self.PRESERVE_ONNX - ignore_attributes
if pickle_metadata:
# Save the rest of parameters
preserved = self.PRESERVE_ONNX - ignore_attributes

preserved_dict = {item: getattr(self, item) for item in preserved}
torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict},
path, pickle_module=pickle_module, **kwargs)
preserved_dict = {item: getattr(self, item) for item in preserved}
torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict},
path, pickle_module=pickle_module, **kwargs)

elif use_openvino:
elif fmt == "openvino":
import openvino as ov

path_openvino = path_openvino or (path + '_openvino')
path_openvino = path if not pickle_metadata else os.path.splitext(path)[0] + ".openvino"
if os.path.splitext(path_openvino)[-1] == '':
path_openvino = f'{path_openvino}.xml'

@@ -1751,18 +1758,33 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op

ov.save_model(model, output_model=path_openvino)

# Save the rest of parameters
preserved = self.PRESERVE_OPENVINO - ignore_attributes
if pickle_metadata:
# Save the rest of parameters
preserved = self.PRESERVE_OPENVINO - ignore_attributes
preserved_dict = {item: getattr(self, item) for item in preserved}
torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict},
path, pickle_module=pickle_module, **kwargs)

elif fmt == "safetensors":
from safetensors.torch import save_file
state_dict = self.model.state_dict()

path_safetensors = path if not pickle_metadata else os.path.splitext(path)[0] + ".safetensors"
save_file(state_dict, path_safetensors)

preserved = self.PRESERVE_SAFETENSORS - ignore_attributes
preserved_dict = {item: getattr(self, item) for item in preserved}

if pickle_metadata:
torch.save({'safetensors': True, 'path_safetensors': path_safetensors, **preserved_dict},
path, pickle_module=pickle_module, **kwargs)

else:
preserved = set(self.PRESERVE) - set(ignore_attributes)
torch.save({item: getattr(self, item) for item in preserved},
path, pickle_module=pickle_module, **kwargs)
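A hedged usage sketch of the reworked `save` contract (file names and the batch size are illustrative, not taken from the PR):

model.save('model.pt')                                  # fmt inferred from the '.pt' extension
model.save('model.safetensors', pickle_metadata=False)  # weights file only, no pickled metadata
model.save('model.pt', fmt='onnx', batch_size=16)       # graph goes to 'model.onnx', metadata to 'model.pt'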

def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
def load(self, file, fmt=None, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
""" Load a torch model from a file.

If the model was saved in ONNX format (refer to :meth:`.save` for more info), we fix the microbatch size
@@ -1772,6 +1794,8 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
----------
file : str, PathLike, io.Bytes
a file where a model is stored.
fmt : str, optional
Weights format. Available formats: "pt", "onnx", "openvino", "safetensors".
make_infrastructure : bool
Whether to re-create model loss, optimizer, scaler and decay.
mode : str
@@ -1793,6 +1817,53 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
else:
self._parse_devices()

if isinstance(file, str):
if fmt == "safetensors" or (fmt is None and file.endswith(".safetensors")):
from safetensors.torch import load_file
state_dict = load_file(file)

inputs = self.make_placeholder_data(to_device=True)
with torch.no_grad():
self.model = Network(inputs=inputs, config=self.config, device=self.device)

self.model.load_state_dict(state_dict)

self.model_to_device()

if make_infrastructure:
self.make_infrastructure()

self.set_model_mode(mode)

return

if fmt == "onnx" or (fmt is None and file.endswith(".onnx")):
try:
from onnx2torch import convert
except ImportError as e:
raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e

model = convert(file).eval()
self.model = model

self.model_to_device()
Review comment (Member): check that the call `model.load('tmp.safetensors', fmt='safetensors', pickle_metadata=False, device='cpu')` creates the model on CUDA.
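# A hedged sketch of the reviewer's check; assumes `config` describes the same
# architecture the weights were saved from and that a CUDA device is present:
from batchflow.models.torch import TorchModel
model = TorchModel(config={**config, 'device': 'cpu'})
model.load('tmp.safetensors', fmt='safetensors')
# The concern: parameters may still land on cuda despite device='cpu'
assert next(model.model.parameters()).device.type == 'cpu'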


if make_infrastructure:
self.make_infrastructure()

self.set_model_mode(mode)

return

if fmt == "openvino" or (fmt is None and file.endswith(".openvino")):
model = OVModel(model_path=file, **model_load_kwargs)
self.model = model

self._loaded_from_openvino = True
self.disable_training = True

return

kwargs['map_location'] = self.device

# Load items from disk storage and set them as instance attributes
Expand Down Expand Up @@ -1828,6 +1899,11 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
self._loaded_from_onnx = True
self.disable_training = True

if "safetensors" in checkpoint:
from safetensors.torch import load_file
state_dict = load_file(checkpoint['path_safetensors'], device=str(self.device))
self.model.load_state_dict(state_dict)

self.model_to_device()

if make_infrastructure:
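A hedged round-trip sketch tying the new save and load paths together (`config` is assumed to describe the same architecture as the saved weights):

model.save('model.safetensors', pickle_metadata=False)   # weights only, no pickled metadata
restored = TorchModel(config)
restored.load('model.safetensors')                        # fmt inferred; network rebuilt from config, weights applied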
1 change: 1 addition & 0 deletions batchflow/tests/research_test.py
@@ -570,6 +570,7 @@ def f(a):
assert research.results.df.iloc[0].a == f(2)
assert research.results.df.iloc[0].b == f(3)

@pytest.mark.slow
@pytest.mark.parametrize('dump_results', [False, True])
@pytest.mark.parametrize('redirect_stdout', [True, 0, 1, 2, 3])
@pytest.mark.parametrize('redirect_stderr', [True, 0, 1, 2, 3])
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -74,6 +74,10 @@ telegram = [
"pillow>=9.4,<11.0",
]

safetensors = [
"safetensors>=0.5.3",
]
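The new optional dependency group makes the safetensors backend an explicit install-time choice; assuming the package is installed from PyPI under the `batchflow` name:

pip install "batchflow[safetensors]"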

other = [
"urllib3>=1.25"
]