Add save in safetensors format #784

Conversation
Hello @igor-iusupov, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello team, gemini-code-assist here to provide a summary of this pull request. This PR introduces the capability to save PyTorch models in the safetensors format within the batchflow library. It achieves this by adding a new boolean parameter, `use_safetensors`, to the existing `save` method of the base PyTorch model class. When this parameter is set to `True`, the method now uses the `safetensors.torch.save_file` function to save the model's state dictionary, offering an alternative to the default PyTorch pickle format or other supported formats like ONNX and OpenVINO.
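As a concrete illustration, the new flag would be used roughly like this (a hypothetical usage sketch: `model` stands for any built batchflow PyTorch model instance and the filenames are placeholders; only the `use_safetensors` parameter itself comes from this PR):

```python
# Hypothetical usage sketch; `model` is a built batchflow torch model instance.
model.save('weights.safetensors', use_safetensors=True)  # new: state_dict via safetensors
model.save('model.pt')                                   # existing default: PyTorch pickle
```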
Highlights
- Safetensors Support: Adds the ability to save PyTorch models using the safetensors format.
- Model Saving: Modifies the base model's `save` method to include an option for safetensors output.
Changelog
- batchflow/models/torch/base.py
  - Added the `use_safetensors` boolean parameter to the `save` method signature (around line 1671).
  - Updated the docstring for the `save` method to document the new `use_safetensors` parameter (around line 1691).
  - Implemented the logic within the `save` method to handle saving the model's state dictionary using `safetensors.torch.save_file` when `use_safetensors` is True (around lines 1761-1764).
  - Removed a blank line (around line 17).
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces the capability to save PyTorch models in the safetensors format, which is a valuable addition for model interoperability and security. The core implementation of saving the state_dict using safetensors.torch.save_file is correct.
My review focuses on improving clarity for users regarding how this new option interacts with existing save functionalities, particularly concerning metadata preservation, the usage of the path argument, and the handling of mutually exclusive format flags. Addressing these points will enhance the robustness and usability of this feature.
Summary of Findings
- Documentation Clarity for `safetensors` Save Option: The method docstring should be enhanced to clearly explain: 1) How the `path` argument is utilized when `use_safetensors=True`. 2) The mutual exclusivity and precedence if multiple `use_...` format flags are enabled. 3) Whether associated metadata (like model configuration, training iteration) is saved with the `safetensors` format, as it currently appears to save only the `state_dict`.
- Metadata Preservation Consistency with `safetensors`: The current `safetensors` implementation only saves the model's `state_dict`, which differs from the ONNX and OpenVINO saving options that also store other model attributes (e.g., config, iteration count). This inconsistency could lead to loss of information if not intended. The `safetensors.torch.save_file` function supports a `metadata` argument, which could be used to store these attributes for consistency. If this omission is by design, it needs to be prominently documented.
- Handling of Mutually Exclusive Save Format Flags: The `save` method's behavior when multiple format flags (`use_onnx`, `use_openvino`, `use_safetensors`) are set to `True` should be more robust or clearly documented. Currently, only the first format encountered in the `if`/`elif` conditional chain is used, which might not be obvious to users. Consider raising an error or explicitly documenting the priority; a minimal guard sketch follows this list.
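The guard could look roughly like this (a hypothetical sketch placed near the top of `save`; the flag names come from this PR, but the check itself is an assumed illustration, not the PR's actual code):

```python
# Hypothetical mutual-exclusivity check; raises early instead of silently
# honoring only the first flag in the if/elif chain.
requested = [name for name, flag in (('onnx', use_onnx),
                                     ('openvino', use_openvino),
                                     ('safetensors', use_safetensors)) if flag]
if len(requested) > 1:
    raise ValueError(f"Only one save format may be requested at a time, got: {requested}")
```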
Merge Readiness
This pull request introduces a useful feature for saving models in the safetensors format. However, to ensure clarity for users and consistency with existing functionalities, I recommend addressing the points raised, particularly regarding metadata preservation and the handling/documentation of mutually exclusive save options. The identified high severity issue concerning metadata should be carefully considered. Once these aspects are clarified and potentially adjusted, the PR will be in a much stronger position for merging. As I am not authorized to approve pull requests, further review and approval by authorized maintainers will be necessary.
batchflow/models/torch/base.py
Outdated
```python
elif use_safetensors:
    from safetensors.torch import save_file
    state_dict = self.model.state_dict()
    save_file(state_dict, path)
```
This new `elif` block correctly implements saving to safetensors. There are a couple of points to consider for enhancing robustness and consistency with other save formats:

- Handling Multiple Format Flags: If a user inadvertently sets multiple format flags to `True` (e.g., `use_onnx=True` and `use_safetensors=True`), only the first one encountered in the `if`/`elif` chain will be executed. This could be non-obvious. Would it be more user-friendly to:
  - Raise a `ValueError` if more than one of `use_onnx`, `use_openvino`, `use_safetensors` is `True`?
  - Or, at least, clearly document the order of precedence in the method's docstring?

- Metadata Preservation Consistency: The ONNX and OpenVINO saving paths preserve a dictionary of model attributes (derived from `self.PRESERVE_ONNX` and `self.PRESERVE_OPENVINO`) alongside the model weights. The current `safetensors` implementation, however, only saves the `model.state_dict()`.

  Is this omission of metadata for the `safetensors` format an intentional design choice? If users rely on reloading models with their full context (e.g., `full_config`, `iteration` count, which are part of `self.PRESERVE`), this difference in behavior could be critical.

  The `safetensors.torch.save_file` function supports a `metadata: Dict[str, str]` argument. To maintain consistency, it would be beneficial to also save relevant preserved attributes when using `safetensors`. This would involve:
  - Defining a set of attributes to preserve for `safetensors` (e.g., `self.PRESERVE_SAFETENSORS` or using a subset of `self.PRESERVE`).
  - Serializing these attributes into a string-to-string dictionary compatible with `safetensors` metadata.

  For example:
```python
# Inside the `elif use_safetensors:` block:
from safetensors.torch import save_file
import json  # For serializing complex objects

state_dict = self.model.state_dict()

# Example: Prepare metadata (attributes and serialization logic need careful consideration)
# Note: safetensors metadata values must be strings.
ignore_attributes = ignore_attributes or set()
preserved_attrs_for_sf = set(self.PRESERVE) - ignore_attributes  # Or a dedicated PRESERVE_SAFETENSORS

metadata_dict = {}
for item_key in preserved_attrs_for_sf:
    if hasattr(self, item_key):
        attr_val = getattr(self, item_key)
        if isinstance(attr_val, (str, int, float, bool)):
            metadata_dict[item_key] = str(attr_val)
        elif isinstance(attr_val, (Config, dict, list)):
            try:
                # Attempt to JSON serialize complex types
                metadata_dict[item_key] = json.dumps(attr_val.config if isinstance(attr_val, Config) else attr_val)
            except TypeError:
                # Fallback for non-serializable types
                metadata_dict[item_key] = f"Unserializable: <{type(attr_val).__name__}>"
        # else: might log a warning for attributes not easily convertible to string metadata

save_file(state_dict, path, metadata=metadata_dict)
```
If omitting metadata is a deliberate choice (e.g., for a minimal, weights-only save), this distinction should be very clearly documented.
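For completeness, metadata written this way can be read back on the loading side; a minimal sketch using the standard `safetensors.safe_open` API (the filename is a placeholder):

```python
from safetensors import safe_open

# Read back both the tensors and the string-to-string metadata saved by save_file.
with safe_open('weights.safetensors', framework='pt', device='cpu') as f:
    metadata = f.metadata()  # dict[str, str], or None if no metadata was saved
    state_dict = {key: f.get_tensor(key) for key in f.keys()}
```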
batchflow/models/torch/base.py
Outdated
```python
use_safetensors: bool
    Whether to store model as safetensors file.
```
The docstring for `use_safetensors` clearly states its purpose. To further improve user understanding of the `save` method with this new option, could we expand the documentation (either here or in the main method docstring) to cover:

- `path` Argument Behavior: Clarify that when `use_safetensors=True`, the main `path` argument is directly used as the filename for the `.safetensors` file. This is a key difference from the default behavior where `path` might point to a PyTorch pickle containing metadata, or from `path_onnx`/`path_openvino` which are separate arguments.

- Mutual Exclusivity of Formats: Explain how `use_onnx`, `use_openvino`, and `use_safetensors` interact if multiple are set to `True`. The current `if`/`elif`/`else` structure implies mutual exclusivity with a specific order of precedence (ONNX > OpenVINO > SafeTensors > default PyTorch save). Making this explicit would prevent user confusion.

- Metadata with `safetensors`: It would be helpful to document whether metadata (attributes from `self.PRESERVE` like `full_config`, `iteration`, etc.) is saved when `use_safetensors=True`. The current implementation appears to save only the `state_dict`, which differs from the ONNX and OpenVINO options that store a dictionary including such metadata. Highlighting this difference is important. A docstring sketch covering these points follows this list.
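A hypothetical sketch of how the expanded docstring could read (the precedence and metadata behavior described here are taken from this review's observations, not from confirmed library documentation):

```python
def save(self, path, use_onnx=False, use_openvino=False, use_safetensors=False, **kwargs):
    """Save the underlying model.

    Parameters
    ----------
    use_safetensors : bool
        Whether to store the model as a safetensors file. When True, `path` is
        used directly as the output filename and only `model.state_dict()` is
        written; attributes from `self.PRESERVE` (e.g. `full_config`,
        `iteration`) are not saved. The format flags are mutually exclusive,
        with precedence ONNX > OpenVINO > safetensors > default PyTorch pickle.
    """
```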
Commits: Try to run · Fix linting issues

Force-pushed from 1805059 to 028d383.
batchflow/models/torch/base.py
Outdated
```python
model = convert(file).eval()
self.model = model

self.model_to_device()
```
Check that call: `model.load('tmp.safetensors', fmt='safetensors', pickle_metadata=False, device='cpu')` creates the model on cuda.
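One way to honor the requested device, sketched under assumptions (the `device` argument of `load_file` is real `safetensors` API; how it slots into batchflow's `load` is assumed, not taken from the PR):

```python
from safetensors.torch import load_file

# Load tensors directly onto the requested device rather than the default one.
state_dict = load_file('tmp.safetensors', device='cpu')
self.model.load_state_dict(state_dict)
# Skip (or parameterize) the unconditional self.model_to_device() call so that
# an explicit device='cpu' is not overridden by a default cuda placement.
```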
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Alexey Kozhevin <[email protected]>
Force-pushed from 1296356 to 169f58a.
Pull Request Overview
This PR adds support for saving models in the safetensors format while also refactoring the save/load methods for various model formats. Key changes include adding the safetensors dependency in pyproject.toml, tagging a slow test in research_test.py, and updating the save and load methods in the torch base model to accommodate new format options.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| pyproject.toml | Adds safetensors dependency |
| batchflow/tests/research_test.py | Adds @pytest.mark.slow decorator for a test case |
| batchflow/models/torch/base.py | Refactors model saving/loading to support "pt", "onnx", "openvino", and "safetensors" |
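Given the `fmt='safetensors'` call quoted later in this thread, the refactor appears to route on a format string; a minimal, assumed sketch of such dispatch (not the PR's actual code, and the ONNX/OpenVINO branches are elided):

```python
def save(self, path, fmt='pt'):
    """Dispatch saving on a format string; a simplified illustration."""
    if fmt == 'safetensors':
        from safetensors.torch import save_file
        save_file(self.model.state_dict(), path)
    elif fmt == 'pt':
        import torch
        torch.save(self.model.state_dict(), path)
    elif fmt in ('onnx', 'openvino'):
        raise NotImplementedError("exporter-specific logic elided in this sketch")
    else:
        raise ValueError(f"Unknown fmt: {fmt!r}")
```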
Comments suppressed due to low confidence (1)
batchflow/models/torch/base.py:1859
- The variable `model_load_kwargs` is used without being defined or passed. Please define it or adjust the parameters to avoid a runtime error.

```python
model = OVModel(model_path=file, **model_load_kwargs)
```
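One possible fix, sketched under the assumption that the kwargs arrive as an optional parameter of `load` (the parameter plumbing is hypothetical):

```python
# Hypothetical sketch: default the kwargs so the name is always defined before use.
model_load_kwargs = model_load_kwargs if model_load_kwargs is not None else {}
model = OVModel(model_path=file, **model_load_kwargs)
```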