Add save in safetensors format #784
Merged
Changes from 28 commits (33 commits in total):
8a2fc15  Add save in safetensors format
bdf275a  Add safetensors to reqs
b000c74  Add loading safetensors
a86d4e4  Add logic for onnx weights
1192a21  Update load
d0de633  Update save method
e3f15db  Add pickle_metadata flag
ba767b0  Add return for onnx load
f7cc818  Add preserved_dict for safetensors
d5bba46  Update batchflow/models/torch/base.py
e515c65  Update saving extensions (igor-iusupov)
b491290  Update load model for safetensors
3d9392c  Small update
5304560  Update path logic. Make path optional
7305ea9  Fix for linter
7b5217a  Remove all use_xxx and add format=name
9476e3c  AssertError -> ValueError
88f7378  Format -> fmt
23d4d94  Add fmt to load
fcbb4b7  Update saving model name
5b0f63d  Small update
53d6cc9  Update batchflow/models/torch/base.py
19b135d  Update docstring (igor-iusupov)
8ed42de  Remove device from load_file func
ed57f96  Make fmt optional
169f58a  Update model init
429d462  Fix model_to_device
f632dea  Mark test_redirect_stdout as slow (AlexeyKozhevin)
2a8fce9  Refactoring (AlexeyKozhevin)
c298bd1  Fix names collisions (AlexeyKozhevin)
eb306ba  Update tests and load/save (AlexeyKozhevin)
be86d99  Change logic (AlexeyKozhevin)
739af96  Minor changes (AlexeyKozhevin)
Diff: batchflow/models/torch/base.py

@@ -15,7 +15,6 @@
 from torch import nn
 from torch.optim.swa_utils import AveragedModel, SWALR
-from sklearn.decomposition import PCA
 from ...utils_import import make_delayed_import
@@ -400,6 +399,7 @@ def callable_init(module):  # example of a callable for init
     PRESERVE_ONNX = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
     PRESERVE_OPENVINO = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
+    PRESERVE_SAFETENSORS = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])

     def __init__(self, config=None):
         if not isinstance(config, (dict, Config)):
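The three `PRESERVE_*` constants above all subtract the same live runtime objects from the pickled attribute set. A minimal sketch of that set arithmetic; `PRESERVE` here is a stand-in with illustrative member names, not the real class attribute:

```python
# Illustrative stand-in for the class-level PRESERVE attribute.
# Weight-only formats (ONNX, OpenVINO, safetensors) cannot pickle live
# runtime objects, so those are removed from the preserved set.
PRESERVE = {
    'full_config', 'model', 'loss', 'optimizer', 'scaler', 'decay',
    'iteration', 'microbatch_size',
}

# Same expression shape as in the diff: drop objects tied to a live model
RUNTIME_ONLY = {'model', 'loss', 'optimizer', 'scaler', 'decay'}
PRESERVE_SAFETENSORS = PRESERVE - RUNTIME_ONLY

print(sorted(PRESERVE_SAFETENSORS))
```

Only plain metadata attributes survive the subtraction, which is what makes them safe to `torch.save` next to a weights-only file.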
@@ -987,12 +987,16 @@ def transfer_from_device(self, data, force_float32_dtype=True):

     def model_to_device(self, model=None):
         """ Put model on device(s). If needed, apply DataParallel wrapper. """
-        model = model if model is not None else self.model
+        model_ = model if model is not None else self.model

         if len(self.devices) > 1:
-            model = nn.DataParallel(model, self.devices)
+            model_ = nn.DataParallel(model_, self.devices)
         else:
-            model = model.to(self.device)
+            model_ = model_.to(self.device)

+        if model is None:
+            self.model = model_
+        return model_

     # Apply model to train/predict on given data
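The rename to `model_` matters: the original `model` argument must stay untouched so the method can still tell "called with an explicit model" apart from "called for `self.model`" at the end. A self-contained sketch of that control flow with stand-in objects (`DummyModel`/`Holder` are hypothetical, no torch involved):

```python
class DummyModel:
    """Stand-in for an nn.Module: .to() returns self, as in PyTorch."""
    def to(self, device):
        self.device = device
        return self

class Holder:
    """Minimal stand-in for the model wrapper in the diff."""
    def __init__(self):
        self.device = 'cpu'
        self.devices = ['cpu']   # single device: no DataParallel branch
        self.model = DummyModel()

    def model_to_device(self, model=None):
        # Work on `model_` and leave `model` untouched, so the final
        # `if model is None` check still sees the original argument.
        model_ = model if model is not None else self.model
        model_ = model_.to(self.device)
        if model is None:        # called for self.model: store the result
            self.model = model_
        return model_

holder = Holder()
external = DummyModel()
moved = holder.model_to_device(external)  # explicit model: self.model untouched
```

Had the method kept reassigning `model` itself (as before the fix), the `model is None` test would always be False after the first assignment, and an externally passed model could clobber `self.model`.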
@@ -1669,7 +1673,7 @@ def convert_outputs(self, outputs):

     # Store model
-    def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
+    def save(self, path, fmt=None, pickle_metadata=True,
              batch_size=None, opset_version=13, pickle_module=dill, ignore_attributes=None, **kwargs):
         """ Save underlying PyTorch model along with meta parameters (config, device spec, etc).

@@ -1682,17 +1686,10 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
         ----------
         path : str
             Path to a file where the model data will be stored.
-        use_onnx : bool
-            Whether to store model in ONNX format.
-        path_onnx : str, optional
-            Used only if `use_onnx` is True.
-            If provided, then path to store the ONNX model; default `path_onnx` is `path` with '_onnx' postfix.
-        use_openvino : bool
-            Whether to store model as openvino xml file.
-        path_openvino : str, optional
-            Used only if `use_openvino` is True.
-            If provided, then path to store the openvino model; default `path_openvino` is `path` with '_openvino'
-            postfix.
+        fmt : str, optional
+            Weights format. Available formats: "pt", "onnx", "openvino", "safetensors".
+        pickle_metadata : bool
+            Whether to pickle metadata alongside the weights.
         batch_size : int, optional
             Used only if `use_onnx` is True.
             Fixed batch size of the ONNX module. This is the only viable batch size for this model after loading.
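Callers that used the removed `use_onnx`/`use_openvino` flags now pass `fmt` instead. A hedged sketch of the argument migration; `translate_save_kwargs` is a hypothetical helper, not part of batchflow:

```python
def translate_save_kwargs(use_onnx=False, use_openvino=False, **kwargs):
    """Map the removed boolean flags onto the new `fmt` parameter.

    Hypothetical illustration of the migration described in the diff
    (use_xxx flags replaced by fmt=name).
    """
    if use_onnx and use_openvino:
        raise ValueError("Choose a single format")
    if use_onnx:
        fmt = "onnx"
    elif use_openvino:
        fmt = "openvino"
    else:
        fmt = "pt"   # plain torch.save, the default branch
    return dict(fmt=fmt, **kwargs)

print(translate_save_kwargs(use_onnx=True, batch_size=8))
```

One string-valued `fmt` also rules out the previously expressible but invalid state of both booleans set at once.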
@@ -1706,6 +1703,14 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
         kwargs : dict
             Other keyword arguments, passed directly to :func:`torch.save`.
         """
+        available_formats = ("pt", "onnx", "openvino", "safetensors")
+
+        if fmt is None:
+            fmt = os.path.splitext(path)[-1][1:]
+
+        if fmt not in available_formats:
+            raise ValueError(f"fmt must be in {available_formats}")
+
         dirname = os.path.dirname(path)
         if dirname and not os.path.exists(dirname):
             os.makedirs(dirname)
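When `fmt` is omitted, the format is inferred from the file extension and then validated. A standalone sketch mirroring the `splitext` logic from the diff (`infer_fmt` is a hypothetical name):

```python
import os

AVAILABLE_FORMATS = ("pt", "onnx", "openvino", "safetensors")

def infer_fmt(path, fmt=None):
    """Mirror the diff's logic: fall back to the file extension,
    then validate against the supported formats."""
    if fmt is None:
        # '.safetensors' -> 'safetensors'; a path without extension yields ''
        fmt = os.path.splitext(path)[-1][1:]
    if fmt not in AVAILABLE_FORMATS:
        raise ValueError(f"fmt must be in {AVAILABLE_FORMATS}")
    return fmt

print(infer_fmt("weights.safetensors"))  # -> safetensors
```

Note the edge case this implies: an extensionless path yields an empty string and raises `ValueError`, so such paths need an explicit `fmt`.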
@@ -1720,25 +1725,27 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
             ignore_attributes = []
         ignore_attributes = set(ignore_attributes)

-        if use_onnx:
+        if fmt == "onnx":
             if batch_size is None:
                 raise ValueError('Specify valid `batch_size`, used for model inference!')

             inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False)
-            path_onnx = path_onnx or (path + '_onnx')
+            path_onnx = path if not pickle_metadata else os.path.splitext(path)[0] + ".onnx"
             torch.onnx.export(self.model.eval(), inputs, path_onnx, opset_version=opset_version)

-            # Save the rest of parameters
-            preserved = self.PRESERVE_ONNX - ignore_attributes
-            preserved_dict = {item: getattr(self, item) for item in preserved}
-            torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict},
-                       path, pickle_module=pickle_module, **kwargs)
+            if pickle_metadata:
+                # Save the rest of parameters
+                preserved = self.PRESERVE_ONNX - ignore_attributes
+                preserved_dict = {item: getattr(self, item) for item in preserved}
+                torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict},
+                           path, pickle_module=pickle_module, **kwargs)

-        elif use_openvino:
+        elif fmt == "openvino":
             import openvino as ov

-            path_openvino = path_openvino or (path + '_openvino')
+            path_openvino = path if not pickle_metadata else os.path.splitext(path)[0] + ".openvino"
             if os.path.splitext(path_openvino)[-1] == '':
                 path_openvino = f'{path_openvino}.xml'
@@ -1751,18 +1758,33 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op

             ov.save_model(model, output_model=path_openvino)

-            # Save the rest of parameters
-            preserved = self.PRESERVE_OPENVINO - ignore_attributes
-            preserved_dict = {item: getattr(self, item) for item in preserved}
-            torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict},
-                       path, pickle_module=pickle_module, **kwargs)
+            if pickle_metadata:
+                # Save the rest of parameters
+                preserved = self.PRESERVE_OPENVINO - ignore_attributes
+                preserved_dict = {item: getattr(self, item) for item in preserved}
+                torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict},
+                           path, pickle_module=pickle_module, **kwargs)

+        elif fmt == "safetensors":
+            from safetensors.torch import save_file
+            state_dict = self.model.state_dict()
+
+            path_safetensors = path if not pickle_metadata else os.path.splitext(path)[0] + ".safetensors"
+            save_file(state_dict, path_safetensors)
+
+            preserved = self.PRESERVE_SAFETENSORS - ignore_attributes
+            preserved_dict = {item: getattr(self, item) for item in preserved}
+
+            if pickle_metadata:
+                torch.save({'safetensors': True, 'path_safetensors': path_safetensors, **preserved_dict},
+                           path, pickle_module=pickle_module, **kwargs)

         else:
             preserved = set(self.PRESERVE) - set(ignore_attributes)
             torch.save({item: getattr(self, item) for item in preserved},
                        path, pickle_module=pickle_module, **kwargs)
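With `pickle_metadata=True`, each non-"pt" format writes two files: the weights under a format-specific extension derived from `path`, and a small pickled checkpoint at `path` itself that records where the weights live. A self-contained sketch of only that file layout, with `json` standing in for `torch.save`/`safetensors.torch.save_file` (`save_sketch` is hypothetical):

```python
import json
import os
import tempfile

def save_sketch(path, fmt="safetensors", pickle_metadata=True):
    """Reproduce the file layout from the diff: with metadata, weights go to a
    derived '<stem>.<fmt>' path and the checkpoint goes to `path`; without
    metadata, `path` holds the weights directly. json stands in for the
    real serializers."""
    if pickle_metadata:
        path_weights = os.path.splitext(path)[0] + "." + fmt
    else:
        path_weights = path

    with open(path_weights, "w") as f:
        json.dump({"weights": "stub"}, f)      # stand-in for tensor data

    if pickle_metadata:
        with open(path, "w") as f:             # stand-in for torch.save
            json.dump({fmt: True, "path_" + fmt: path_weights}, f)
    return path_weights

with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "model.pt")
    weights_path = save_sketch(p, fmt="safetensors", pickle_metadata=True)
```

The checkpoint's `{'safetensors': True, 'path_safetensors': ...}` keys are exactly what the `load` side later inspects to find and restore the weights.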
-    def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
+    def load(self, file, fmt=None, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
         """ Load a torch model from a file.

         If the model was saved in ONNX format (refer to :meth:`.save` for more info), we fix the microbatch size

@@ -1772,6 +1794,8 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
         ----------
         file : str, PathLike, io.Bytes
             a file where a model is stored.
+        fmt : str, optional
+            Weights format. Available formats: "pt", "onnx", "openvino", "safetensors".
         make_infrastructure : bool
             Whether to re-create model loss, optimizer, scaler and decay.
         mode : str
@@ -1793,6 +1817,53 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
         else:
             self._parse_devices()

+        if isinstance(file, str):
+            if fmt == "safetensors" or (fmt is None and file.endswith(".safetensors")):
+                from safetensors.torch import load_file
+                state_dict = load_file(file)
+
+                inputs = self.make_placeholder_data(to_device=True)
+                with torch.no_grad():
+                    self.model = Network(inputs=inputs, config=self.config, device=self.device)
+                self.model.load_state_dict(state_dict)
+
+                self.model_to_device()
+                if make_infrastructure:
+                    self.make_infrastructure()
+                self.set_model_mode(mode)
+                return
+
+            if fmt == "onnx" or (fmt is None and file.endswith(".onnx")):
+                try:
+                    from onnx2torch import convert
+                except ImportError as e:
+                    raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e
+
+                model = convert(file).eval()
+                self.model = model
+
+                self.model_to_device()
+                if make_infrastructure:
+                    self.make_infrastructure()
+                self.set_model_mode(mode)
+                return
+
+            if fmt == "openvino" or (fmt is None and file.endswith(".openvino")):
+                model = OVModel(model_path=file, **model_load_kwargs)
+                self.model = model
+
+                self._loaded_from_openvino = True
+                self.disable_training = True
+                return
+
         kwargs['map_location'] = self.device

         # Load items from disk storage and set them as instance attributes

@@ -1828,6 +1899,11 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
             self._loaded_from_onnx = True
             self.disable_training = True

+        if "safetensors" in checkpoint:
+            from safetensors.torch import load_file
+            state_dict = load_file(checkpoint['path_safetensors'], device=device)
+            self.model.load_state_dict(state_dict)
+
         self.model_to_device()

         if make_infrastructure:
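The load side mirrors save: an explicit `fmt` wins, otherwise the file suffix selects the branch, and anything else falls through to `torch.load` of a pickled checkpoint. A sketch of just that dispatch (`pick_loader` is a placeholder that returns branch names instead of running real loaders):

```python
def pick_loader(file, fmt=None):
    """Return which loading branch the diff's load() would take.
    Placeholder for the real method: strings instead of actual loaders."""
    if isinstance(file, str):
        if fmt == "safetensors" or (fmt is None and file.endswith(".safetensors")):
            return "safetensors"
        if fmt == "onnx" or (fmt is None and file.endswith(".onnx")):
            return "onnx"
        if fmt == "openvino" or (fmt is None and file.endswith(".openvino")):
            return "openvino"
    # anything else (including non-str file objects) falls through
    # to torch.load with pickled metadata
    return "torch"

print(pick_loader("model.safetensors"))  # -> safetensors
```

Note the ordering consequence: a pickled checkpoint saved with `pickle_metadata=True` keeps a plain extension such as `.pt`, so it takes the `torch` branch and only then follows its embedded `path_safetensors`/`path_onnx` pointer to the weights.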