Expose weights_only argument in Fabric.load #21470
base: master
Conversation
for more information, see https://pre-commit.ci
src/lightning/fabric/fabric.py
Outdated
    """Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.).

    How and which processes load gets determined by the `strategy`.
    This method must be called on all processes!

    Args:
        path: A path to where the file is located.
        state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
            If no state is given, then the checkpoint will be returned in full.
        strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.

    Returns:
        The remaining items that were not restored into the given state dictionary. If no state dictionary is
        given, the full checkpoint will be returned.

    Example::

        # Load full checkpoint
        checkpoint = fabric.load("checkpoint.pth")

        # Load into existing objects
        state = {"model": model, "optimizer": optimizer}
        remainder = fabric.load("checkpoint.pth", state)
        epoch = remainder.get("epoch", 0)

    """
The original docstring should be extended with the additional kwarg `weights_only` instead of removed.
I’ve updated the docstring to preserve the original content and extended it with the weights_only argument, including fixing the example formatting. Thanks for the review!
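As a rough sketch, the additive docstring change discussed above could look like the following (the wording of the new `weights_only` entry is illustrative, not quoted from the PR):

```python
# Illustrative only: extending the existing docstring with the new kwarg while
# keeping the original content intact. The wording of the new entry is hypothetical.
def load(self, path, state=None, strict=True, *, weights_only=None):
    """Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.).

    Args:
        path: A path to where the file is located.
        state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
        strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
        weights_only: Forwarded to the strategy's checkpoint loading call. Pass ``False`` to allow
            checkpoints that contain non-tensor objects (PyTorch 2.6 changed the ``torch.load``
            default to ``weights_only=True``).
    """
```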
…hal-17/pytorch-lightning into fabric-load-weights-only
for more information, see https://pre-commit.ci
    # We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
    # (for user metadata) wouldn't show up in the original dict, so we need to copy the data back.
The comments here should still be preserved.
Thanks for the review 👍
I’ve resolved the merge conflict in Fabric.load, preserved all existing comments and examples, and kept the weights_only change purely additive. Please let me know if everything looks good now.
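The copy-back behavior those preserved comments describe can be sketched like this (the function and names below are hypothetical, for illustration only):

```python
# Hypothetical sketch of the preserved comments' point: unwrapping the state
# produces a *new* dictionary, so in-place updates made during loading must be
# copied back into the caller's original dict to remain visible.

def load_with_copy_back(state, loader):
    unwrapped = dict(state)        # unwrapping objects yields a new dictionary
    remainder = loader(unwrapped)  # the loader may update `unwrapped` in place
    # Copy the data back so the caller's original dict sees the updates.
    state.update(unwrapped)
    return remainder

user_state = {"epoch": 0}

def fake_loader(d):
    d["epoch"] = 5                 # in-place user-metadata update on the copy
    return {"extra_key": 1}

remainder = load_with_copy_back(user_state, fake_loader)
print(user_state["epoch"])  # -> 5, because the update was copied back
```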
…/CodeVishal-17/pytorch-lightning into fabric-load-weights-only" This reverts commit 6143e65, reversing changes made to debde18.
…hal-17/pytorch-lightning into fabric-load-weights-only
for more information, see https://pre-commit.ci
Codecov Report
✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

@@ Coverage Diff @@
##           master   #21470   +/- ##
=======================================
- Coverage      87%      87%    -0%
=======================================
  Files         270      270
  Lines       24059    24059
=======================================
- Hits        20862    20859     -3
- Misses       3197     3200     +3
What does this PR do?
Exposes the `weights_only` argument in `Fabric.load()` and forwards it to the underlying strategy checkpoint loading call. This restores compatibility with mixed-object checkpoints after the PyTorch 2.6 change where `torch.load` defaults to `weights_only=True`.

Motivation and context
With PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which prevents loading checkpoints that contain non-tensor objects (such as optimizers, schedulers, or metadata). While `strategy.load_checkpoint(..., weights_only=False)` already supports this use case, `Fabric.load()` did not expose the argument, making it impossible for users to opt out of the new default behavior when using Fabric.

This PR adds a keyword-only `weights_only` argument to `Fabric.load()` without changing the existing default behavior.

Fixes #21459
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--21470.org.readthedocs.build/en/21470/