diff --git a/batchflow/models/torch/base.py b/batchflow/models/torch/base.py
index 2334fd927..54bd5372b 100755
--- a/batchflow/models/torch/base.py
+++ b/batchflow/models/torch/base.py
@@ -15,7 +15,6 @@
 from torch import nn
 from torch.optim.swa_utils import AveragedModel, SWALR
-from sklearn.decomposition import PCA
 
 from ...utils_import import make_delayed_import
@@ -388,6 +387,8 @@ def callable_init(module):    # example of a callable for init
         'microbatch_size': 16,          # size of microbatches at training
     }
     """
+    AVAILABLE_FORMATS = ("onnx", "openvino", "safetensors")
+
     PRESERVE = set([
         'full_config', 'config', 'model',
         'inputs_shapes', 'targets_shapes', 'classes',
@@ -400,9 +401,10 @@ def callable_init(module):    # example of a callable for init
 
     PRESERVE_ONNX = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
     PRESERVE_OPENVINO = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
+    PRESERVE_SAFETENSORS = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
 
     def __init__(self, config=None):
-        if not isinstance(config, (dict, Config)):
+        if config is not None and not isinstance(config, (dict, Config)):
             config = {'load/file': config}
 
         self.model_lock = Lock()
@@ -987,12 +989,16 @@ def transfer_from_device(self, data, force_float32_dtype=True):
 
     def model_to_device(self, model=None):
         """ Put model on device(s). If needed, apply DataParallel wrapper. """
-        model = model if model is not None else self.model
+        model_ = model if model is not None else self.model
 
         if len(self.devices) > 1:
-            model = nn.DataParallel(model, self.devices)
+            model_ = nn.DataParallel(model_, self.devices)
         else:
-            model = model.to(self.device)
+            model_ = model_.to(self.device)
+
+        if model is None:
+            self.model = model_
+        return model_
 
 
     # Apply model to train/predict on given data
@@ -1669,8 +1675,8 @@ def convert_outputs(self, outputs):
 
 
     # Store model
-    def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
-             batch_size=None, opset_version=13, pickle_module=dill, ignore_attributes=None, **kwargs):
+    def save(self, path, fmt=None, pickle_metadata=True, batch_size=None, opset_version=13, pickle_module=dill,
+             ignore_attributes=('optimizer', 'decay'), **kwargs):
         """ Save underlying PyTorch model along with meta parameters (config, device spec, etc).
 
-        If `use_onnx` is set to True, then the model is converted to ONNX format and stored in a separate file.
+        If `fmt` is provided, then the model is converted to that format and stored in a separate file.
@@ -1682,17 +1688,10 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
         ----------
         path : str
             Path to a file where the model data will be stored.
-        use_onnx: bool
-            Whether to store model in ONNX format.
-        path_onnx : str, optional
-            Used only if `use_onnx` is True.
-            If provided, then path to store the ONNX model; default `path_onnx` is `path` with '_onnx' postfix.
-        use_openvino: bool
-            Whether to store model as openvino xml file.
-        path_openvino : str, optional
-            Used only if `use_openvino` is True.
-            If provided, then path to store the openvino model; default `path_openvino` is `path` with '_openvino'
-            postfix.
+        fmt : str, optional
+            Weights format. Available formats: "onnx", "openvino", "safetensors".
+        pickle_metadata : bool
+            Whether to dump metadata (see `PRESERVE` attribute) to the file.
         batch_size : int, optional
-            Used only if `use_onnx` is True. Fixed batch size of the ONNX module.
-            This is the only viable batch size for this model after loading.
+            Used only if `fmt` is "onnx" or "openvino". Fixed batch size of the converted module.
+            For ONNX, this is the only viable batch size for this model after loading.
@@ -1706,11 +1705,13 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
         kwargs : dict
             Other keyword arguments, passed directly to :func:`torch.save`.
         """
+        pickle_module = dill if pickle_module is None else pickle_module
+
         dirname = os.path.dirname(path)
         if dirname and not os.path.exists(dirname):
             os.makedirs(dirname)
 
-        # Unwrap DDP model
+        # Unwrap DDP if needed
         if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
             self.model = self.model.module
 
@@ -1720,49 +1721,142 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
             ignore_attributes = []
         ignore_attributes = set(ignore_attributes)
 
-        if use_onnx:
-            if batch_size is None:
-                raise ValueError('Specify valid `batch_size`, used for model inference!')
+        if fmt is None:
+            self._save_torch(path, pickle_metadata, ignore_attributes, pickle_module, kwargs)
+            return
+
+        if fmt not in self.AVAILABLE_FORMATS:
+            raise ValueError(f"fmt must be in {self.AVAILABLE_FORMATS} but got {fmt}!")
+
+        if fmt == "onnx":
+            self._save_onnx(path, pickle_metadata, batch_size, opset_version,
+                            ignore_attributes, pickle_module, kwargs)
+        elif fmt == "openvino":
+            self._save_openvino(path, pickle_metadata, batch_size,
+                                ignore_attributes, pickle_module, kwargs)
+        elif fmt == "safetensors":
+            self._save_safetensors(path, pickle_metadata,
+                                   ignore_attributes, pickle_module, kwargs)
+        else:
+            raise RuntimeError(f"Unsupported format: {fmt}")
 
-            inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False)
-            path_onnx = path_onnx or (path + '_onnx')
-            torch.onnx.export(self.model.eval(), inputs, path_onnx, opset_version=opset_version)
 
-            # Save the rest of parameters
-            preserved = self.PRESERVE_ONNX - ignore_attributes
+    def _save_torch(self, path, pickle_metadata, ignore_attributes, pickle_module, kwargs):
+        """ Save the model in PyTorch format. """
""" + if pickle_metadata: + preserved = set(self.PRESERVE) - ignore_attributes + saved_data = {item: getattr(self, item) for item in preserved} + torch.save(saved_data, path, pickle_module=pickle_module, **kwargs) + else: + torch.save({'model': self.model}, path, pickle_module=pickle_module, **kwargs) + + + def _save_onnx(self, path, pickle_metadata, batch_size, opset_version, + ignore_attributes, pickle_module, kwargs): + """ Save the model in ONNX format.""" + if batch_size is None: + raise ValueError("`batch_size` must be specified when saving in ONNX format!") + + inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False) + self.sanitize_module_names(self.model) + + if pickle_metadata: + name, ext = os.path.splitext(path) + if ext == ".onnx": + raise ValueError("Path should not have .onnx extension when saving with metadata!") + onnx_path = name + ".onnx" + else: + onnx_path = path - preserved_dict = {item: getattr(self, item) for item in preserved} - torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict}, - path, pickle_module=pickle_module, **kwargs) + torch.onnx.export( + self.model.eval(), + inputs, + onnx_path, + opset_version=opset_version, + ) - elif use_openvino: - import openvino as ov + if pickle_metadata: + preserved = self.PRESERVE_ONNX - ignore_attributes + meta = { + "onnx": True, + "path_onnx": onnx_path, + "onnx_batch_size": batch_size, + **{k: getattr(self, k) for k in preserved} + } + torch.save(meta, path, pickle_module=pickle_module, **kwargs) - path_openvino = path_openvino or (path + '_openvino') - if os.path.splitext(path_openvino)[-1] == '': - path_openvino = f'{path_openvino}.xml' - # Save model - model = self.model.eval() + def _save_openvino(self, path, pickle_metadata, batch_size, + ignore_attributes, pickle_module, kwargs): + """ Save the model in OpenVINO format. 
""" + import openvino as ov - if not isinstance(self.model, ov.Model): - inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False) - model = ov.convert_model(model, example_input=inputs) + model = self.model.eval() + if not isinstance(model, ov.Model): + inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False) + model = ov.convert_model(model, example_input=inputs) + + name, ext = os.path.splitext(path) + if pickle_metadata: + if ext == ".xml": + raise ValueError("Path should not have .xml extension when saving with metadata!") + openvino_path = name + ".xml" + else: + if ext != ".xml": + raise ValueError("Path should have .xml extension when saving OpenVINO model!") + openvino_path = path - ov.save_model(model, output_model=path_openvino) + ov.save_model(model, output_model=openvino_path) - # Save the rest of parameters + if pickle_metadata: preserved = self.PRESERVE_OPENVINO - ignore_attributes - preserved_dict = {item: getattr(self, item) for item in preserved} - torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict}, - path, pickle_module=pickle_module, **kwargs) + meta = { + "openvino": True, + "path_openvino": openvino_path, + **{k: getattr(self, k) for k in preserved} + } + torch.save(meta, path, pickle_module=pickle_module, **kwargs) + + + def _save_safetensors(self, path, pickle_metadata, + ignore_attributes, pickle_module, kwargs): + """ Save the model in Safetensors format.""" + from safetensors.torch import save_file + + if pickle_metadata: + name, ext = os.path.splitext(path) + if ext == ".safetensors": + raise ValueError("Path should not have .safetensors extension when saving with metadata!") + safetensors_path = os.path.splitext(path)[0] + ".safetensors" else: - preserved = set(self.PRESERVE) - set(ignore_attributes) - torch.save({item: getattr(self, item) for item in preserved}, - path, pickle_module=pickle_module, **kwargs) + safetensors_path = path + + save_file(self.model.state_dict(), safetensors_path) + + if pickle_metadata: + preserved = self.PRESERVE_SAFETENSORS - ignore_attributes + meta = { + "safetensors": True, + "path_safetensors": safetensors_path, + **{k: getattr(self, k) for k in preserved} + } + torch.save(meta, path, pickle_module=pickle_module, **kwargs) - def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs): + @classmethod + def sanitize_module_names(cls, module): + """ + Recursively rename submodules to ensure names are safe for ONNX export. + Replaces spaces, quotes, and commas with underscores. + """ + # Work on a list of keys to avoid mutating dict while iterating + keys = list(module._modules.keys()) # noqa: SLF001 + for key in keys: + child = module._modules[key] # noqa: SLF001 + clean_key = key.replace(' ', '_').replace('"', '').replace(',', '_') + if clean_key != key: + module._modules[clean_key] = module._modules.pop(key) # noqa: SLF001 + # Recurse + cls.sanitize_module_names(child) + + def load(self, path, fmt=None, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs): """ Load a torch model from a file. If the model was saved in ONNX format (refer to :meth:`.save` for more info), we fix the microbatch size @@ -1770,8 +1867,10 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, Parameters ---------- - file : str, PathLike, io.Bytes + path : str, PathLike, io.Bytes a file where a model is stored. + fmt: optional str + Weights format. 
Available formats: "pt", "onnx", "openvino", "safetensors" make_infrastructure : bool Whether to re-create model loss, optimizer, scaler and decay. mode : str @@ -1782,8 +1881,8 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, Other keyword arguments, passed directly to :func:`torch.save`. """ model_load_kwargs = kwargs.pop('model_load_kwargs', {}) - device = kwargs.pop('device', None) + pickle_module = dill if pickle_module is None else pickle_module if device is not None: self.device = device @@ -1793,10 +1892,21 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, else: self._parse_devices() + if isinstance(path, str): + if fmt == "safetensors" or (fmt is None and path.endswith(".safetensors")): + self._load_safetensors(path, make_infrastructure=make_infrastructure, mode=mode) + return + if fmt == "onnx" or (fmt is None and path.endswith(".onnx")): + self._load_onnx(path, make_infrastructure=make_infrastructure, mode=mode) + return + if fmt == "openvino" or (fmt is None and path.endswith(".xml")): + self._load_openvino(path, **model_load_kwargs) + return + kwargs['map_location'] = self.device # Load items from disk storage and set them as insance attributes - checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs) + checkpoint = torch.load(path, pickle_module=pickle_module, **kwargs) # `load_config` is a reference to `self.external_config` used to update `config` # It is required since `self.external_config` may be overwritten in the cycle below @@ -1808,32 +1918,63 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, if 'openvino' in checkpoint: # Load openvino model - model = OVModel(model_path=checkpoint['path_openvino'], **model_load_kwargs) - self.model = model + self._load_openvino(checkpoint['path_openvino'], **model_load_kwargs) + elif 'onnx' in checkpoint: + self._load_onnx(checkpoint['path_onnx'], microbatch_size=checkpoint['onnx_batch_size'], + **model_load_kwargs) + elif "safetensors" in checkpoint: + self._load_safetensors(checkpoint['path_safetensors'], make_infrastructure=make_infrastructure, mode=mode) + + def _load_onnx(self, file, make_infrastructure=False, mode='eval', microbatch_size=None): + """Load a model from ONNX file.""" + try: + from onnx2torch import convert + except ImportError as e: + raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e - self._loaded_from_openvino = True - self.disable_training = True + model = convert(file).eval() + self.model = model + if microbatch_size: + self.microbatch_size = microbatch_size - else: - # Load model from onnx, if needed - if 'onnx' in checkpoint: - try: - from onnx2torch import convert - except ImportError as e: - raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e - - model = convert(checkpoint['path_onnx']).eval() - self.model = model - self.microbatch_size = checkpoint['onnx_batch_size'] - self._loaded_from_onnx = True - self.disable_training = True + self.model_to_device() - self.model_to_device() + if make_infrastructure: + self.make_infrastructure() - if make_infrastructure: - self.make_infrastructure() + self.set_model_mode(mode) - self.set_model_mode(mode) + def _load_safetensors(self, file, make_infrastructure=False, mode='eval'): + """Load a model from Safetensors file.""" + try: + from safetensors.torch import load_file + except ImportError as e: + raise ImportError('Loading model, stored in Safetensors format, 
+
+        state_dict = load_file(file)
+
+        inputs = self.make_placeholder_data(to_device=True)
+        with torch.no_grad():
+            self.model = Network(inputs=inputs, config=self.config, device=self.device)
+
+        self.model.load_state_dict(state_dict)
+
+        self.model_to_device()
+
+        if make_infrastructure:
+            self.make_infrastructure()
+
+        self.set_model_mode(mode)
+
+    def _load_openvino(self, file, **model_load_kwargs):
+        """ Load a model from an OpenVINO file. """
+        model = OVModel(model_path=file, **model_load_kwargs)
+        self.model = model
+
+        self._loaded_from_openvino = True
+        self.disable_training = True
 
 
     # Utilities to use when working with TorchModel
@@ -1856,7 +1997,6 @@ def get_model_reference(obj=None):
             return model_reference
         return None
 
-
     # Debug and profile the performance
     def set_requires_grad(self, requires_grad):
         """ Set `requires_grad` flag for the underlying Pytorch model.
diff --git a/batchflow/models/torch/base_mixins.py b/batchflow/models/torch/base_mixins.py
index c8b7c2723..98a791da5 100644
--- a/batchflow/models/torch/base_mixins.py
+++ b/batchflow/models/torch/base_mixins.py
@@ -5,9 +5,12 @@
 import numpy as np
 import torch
 
-from ...plotter import plot
 from ...decorators import deprecated
+from ...utils_import import try_import
+
+plot = try_import(module='...plotter', package=__name__, attribute='plot',
+                  help='Try `pip install batchflow[image]`!')
 
 
 # Also imports `tensorboard`, if necessary
diff --git a/batchflow/tests/model_save_load_test.py b/batchflow/tests/model_save_load_test.py
index faf54da40..e350841d0 100644
--- a/batchflow/tests/model_save_load_test.py
+++ b/batchflow/tests/model_save_load_test.py
@@ -1,5 +1,6 @@
 """ Test for model saving and loading """
+import os
 import pickle
 
 import pytest
@@ -186,3 +187,49 @@ def test_bare_model(self, save_path, model_class, pickle_module, outputs):
             loaded_predictions = model_load.predict(*args, **kwargs)
 
         assert (np.concatenate(saved_predictions) == np.concatenate(loaded_predictions)).all()
+
+    @pytest.mark.parametrize("fmt", [None, 'onnx', 'openvino', 'safetensors'])
+    @pytest.mark.parametrize("pickle_metadata", [False, True])
+    def test_save_load_format(self, save_path, model_class, fmt, pickle_metadata):
+        num_classes = 10
+        dataset_size = 10
+        image_shape = (2, 100, 100)
+
+        save_kwargs = {
+            None: {},
+            'onnx': dict(batch_size=dataset_size),
+            'openvino': {},
+            'safetensors': {},
+        }
+        load_kwargs = {
+            None: {},
+            'onnx': {},
+            'openvino': {'device': 'cpu'},
+            'safetensors': {},
+        }
+
+        if fmt == 'openvino' and not pickle_metadata:
+            save_path = os.path.splitext(save_path)[0] + '.xml'
+
+        model_config = {
+            'classes': num_classes,
+            'inputs_shapes': image_shape,
+            'output': 'sigmoid'
+        }
+
+        model_save = model_class(config=model_config)
+
+        batch_shape = (dataset_size, *image_shape)
+        images_array = np.random.random(batch_shape)
+
+        inputs = images_array.astype('float32')
+
+        saved_predictions = model_save.predict(inputs, outputs='sigmoid')
+        model_save.save(path=save_path, pickle_metadata=pickle_metadata, fmt=fmt, **save_kwargs[fmt])
+
+        load_config = {} if fmt != 'safetensors' else model_save.config
+        model_load = model_class(config=load_config)
+        model_load.load(path=save_path, fmt='pt' if pickle_metadata else fmt, **load_kwargs[fmt])
+        loaded_predictions = model_load.predict(inputs, outputs='sigmoid')
+
+        assert np.isclose(np.concatenate(saved_predictions), np.concatenate(loaded_predictions), atol=1e-3).all()
diff --git a/batchflow/tests/research_test.py b/batchflow/tests/research_test.py
index f1d36c793..1cf369771 100644
--- a/batchflow/tests/research_test.py
+++ b/batchflow/tests/research_test.py
@@ -570,6 +570,7 @@ def f(a):
         assert research.results.df.iloc[0].a == f(2)
         assert research.results.df.iloc[0].b == f(3)
 
+    @pytest.mark.slow
     @pytest.mark.parametrize('dump_results', [False, True])
     @pytest.mark.parametrize('redirect_stdout', [True, 0, 1, 2, 3])
     @pytest.mark.parametrize('redirect_stderr', [True, 0, 1, 2, 3])
diff --git a/pyproject.toml b/pyproject.toml
index 6bd6d1cf6..b9715e717 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,8 @@ dependencies = [
     "numba>=0.56",
     "llvmlite",
     "scipy>=1.9",
-    "tqdm>=4.19"
+    "tqdm>=4.19",
+    "pytest>=8.3.4",
 ]
 
 [project.optional-dependencies]
@@ -74,6 +75,19 @@ telegram = [
     "pillow>=9.4,<11.0",
 ]
 
+safetensors = [
+    "safetensors>=0.5.3",
+]
+
+onnx = [
+    "onnx>=1.14.0",
+    "onnx2torch>=1.5.0",
+]
+
+openvino = [
+    "openvino>=2025.0.0",
+]
+
 other = [
     "urllib3>=1.25"
 ]
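
# --- Reviewer note: usage sketch for the new save/load API -------------------
# A minimal sketch, not part of the diff. Config values and file paths are
# illustrative, mirroring `test_save_load_format` above; the `fmt` and
# `pickle_metadata` parameters are the ones introduced by this PR. The new
# optional extras (batchflow[onnx], batchflow[openvino], batchflow[safetensors])
# pull in the corresponding third-party libraries.
import numpy as np
from batchflow.models.torch import TorchModel

config = {'classes': 10, 'inputs_shapes': (2, 100, 100), 'output': 'sigmoid'}
model = TorchModel(config=config)

# Default behaviour: one torch pickle holding the model plus PRESERVE metadata.
model.save(path='model.pt')

# ONNX needs a fixed batch size; with metadata enabled the weights go to
# 'model.onnx' and the metadata pickle to 'model' itself, so the given `path`
# must not already end in '.onnx'.
model.save(path='model', fmt='onnx', batch_size=16)

# Weights-only safetensors file: no metadata pickle is written, so the config
# has to be supplied again when restoring.
model.save(path='weights.safetensors', fmt='safetensors', pickle_metadata=False)

restored = TorchModel(config=config)
restored.load(path='weights.safetensors', fmt='safetensors')
batch = np.random.random((4, 2, 100, 100)).astype('float32')
predictions = restored.predict(batch, outputs='sigmoid')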
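
# --- Reviewer note: why `sanitize_module_names` exists -----------------------
# ONNX node names are derived from submodule names, and keys coming from e.g.
# `nn.ModuleDict` or `add_module` may contain spaces, commas or double quotes
# that the exporter can turn into malformed node names. A self-contained sketch
# of the renaming behaviour; the toy module and its names are invented for
# illustration, and the classmethod in this diff applies the same logic.
import torch
from torch import nn

net = nn.Sequential()
net.add_module('block 1, "main"', nn.Linear(4, 4))
net.add_module('block 2', nn.Linear(4, 4))

def sanitize_module_names(module):
    # Copy the keys first: renaming mutates `module._modules` while iterating.
    for key in list(module._modules.keys()):
        child = module._modules[key]
        clean_key = key.replace(' ', '_').replace('"', '').replace(',', '_')
        if clean_key != key:
            module._modules[clean_key] = module._modules.pop(key)
        sanitize_module_names(child)

sanitize_module_names(net)
print([name for name, _ in net.named_modules() if name])
# ['block_1__main', 'block_2']

# After sanitizing, export no longer sees the unsafe characters:
torch.onnx.export(net.eval(), torch.randn(1, 4), 'net.onnx', opset_version=13)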