Add adapt_checkpoint_hparams hook for customizing checkpoint hyperparameter loading #21408
Conversation
Add adapt_checkpoint_hparams hook for customizing checkpoint hyperparameter loading

Fixes Lightning-AI#21255

This commit adds the adapt_checkpoint_hparams() public method to LightningCLI, allowing users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This is particularly useful when using checkpoints from a TrainingModule with a different InferenceModule class that has different __init__ parameters.

Problem: When loading a checkpoint trained with TrainingModule(lr=1e-3) into an InferenceModule() that doesn't accept 'lr' as a parameter, the CLI would fail during instantiation because it tries to pass all checkpoint hyperparameters to the new module class.

Solution: Added an adapt_checkpoint_hparams() hook that is called in _parse_ckpt_path() after loading checkpoint hyperparameters but before applying them. Users can override this method to:

- Remove training-specific hyperparameters (e.g., lr, weight_decay)
- Modify _class_path for subclass mode
- Transform hyperparameter names/values
- Completely disable checkpoint hyperparameters by returning {}

Example usage:

```python
class MyCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, checkpoint_hparams):
        checkpoint_hparams.pop('lr', None)
        checkpoint_hparams.pop('weight_decay', None)
        return checkpoint_hparams
```

This approach is preferable to:

- Disabling checkpoint loading entirely (loses valuable hyperparameter info)
- Adding CLI arguments (deviates from the Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)

The hook provides maximum flexibility while maintaining backward compatibility (the default implementation returns the hyperparameters unchanged).
Pull request overview
This PR adds a public adapt_checkpoint_hparams() hook to LightningCLI that enables users to customize hyperparameters loaded from checkpoints before model instantiation. This addresses the issue of loading checkpoints across different module classes (e.g., from TrainingModule to InferenceModule) where incompatible __init__ parameters would otherwise cause failures.
Key Changes:
- Added `adapt_checkpoint_hparams()` public method with comprehensive documentation
- Integrated the hook into `_parse_ckpt_path()` to allow customization before hyperparameter application
- Maintained backward compatibility with a default no-op implementation
src/lightning/pytorch/cli.py
Outdated
```python
def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
    """Adapt checkpoint hyperparameters before instantiating the model class.

    This method allows for customization of hyperparameters loaded from a checkpoint when
    using a different model class than the one used for training. For example, when loading
    a checkpoint from a TrainingModule to use with an InferenceModule that has different
    ``__init__`` parameters, you can remove or modify incompatible hyperparameters.

    Args:
        checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

    Returns:
        Dictionary of adapted hyperparameters to be used for model instantiation.

    Example::

        class MyCLI(LightningCLI):
            def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
                # Remove training-specific hyperparameters not needed for inference
                checkpoint_hparams.pop("lr", None)
                checkpoint_hparams.pop("weight_decay", None)
                return checkpoint_hparams

    Note:
        If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
        hyperparameters, you may need to modify it as well to point to your new module class.
    """
    return checkpoint_hparams
```
Copilot AI (Dec 6, 2025)
The new adapt_checkpoint_hparams() hook lacks test coverage. Given that tests/tests_pytorch/test_cli.py contains comprehensive tests for checkpoint loading functionality (e.g., test_lightning_cli_ckpt_path_argument_hparams and test_lightning_cli_ckpt_path_argument_hparams_subclass_mode), tests should be added to verify:
- The hook is called when loading checkpoint hyperparameters
- Modifications made in the hook are applied correctly
- Returning an empty dict properly skips checkpoint hyperparameter loading
- The hook works in both regular and subclass modes
src/lightning/pytorch/cli.py
Outdated
```python
    else:
        self.config = parser.parse_args(args)


def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
```
Copilot AI (Dec 6, 2025)
Use lowercase dict instead of Dict for type annotations to align with the modern Python 3.9+ style used throughout this file. Change Dict[str, Any] to dict[str, Any] in both the parameter and return type annotations.
src/lightning/pytorch/cli.py
Outdated
```python
    Example::

        class MyCLI(LightningCLI):
            def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
```
Copilot AI (Dec 6, 2025)
Use lowercase dict instead of Dict for type annotations to align with the modern Python 3.9+ style used throughout this file. Change Dict[str, Any] to dict[str, Any] in both the parameter and return type annotations.
mauvilsa left a comment
It is looking good. However, the subcommand parameter is missing. Also please add unit tests.
src/lightning/pytorch/cli.py
Outdated
```python
    else:
        self.config = parser.parse_args(args)


def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
```
Suggested change:

```diff
-def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
+def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
```
As mentioned in my proposal, the method should receive a subcommand parameter.
src/lightning/pytorch/cli.py
Outdated
| checkpoint_hparams.pop("lr", None) | ||
| checkpoint_hparams.pop("weight_decay", None) |
In this example, removing lr and weight_decay should not be done if the subcommand is fit.
src/lightning/pytorch/cli.py
Outdated
```python
        return

    # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
    hparams = self.adapt_checkpoint_hparams(hparams)
```
Suggested change:

```diff
-    hparams = self.adapt_checkpoint_hparams(hparams)
+    hparams = self.adapt_checkpoint_hparams(subcommand, hparams)
```
…ook and add tests

- Update adapt_checkpoint_hparams signature to include a subcommand parameter, allowing context-aware customization of checkpoint hyperparameters
- Change type annotations to use lowercase dict (Python 3.9+ style)
- Update docstring with subcommand parameter documentation
- Add example showing conditional logic based on subcommand
- Add comprehensive unit tests:
  - test_adapt_checkpoint_hparams_hook: tests that the hook is called and modifications are applied
  - test_adapt_checkpoint_hparams_hook_empty_dict: tests disabling checkpoint hparams loading
  - Tests cover both regular and subclass modes
Thanks for the response. I already updated the signature to take the subcommand parameter, and also added:

```python
def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
    if subcommand != "fit":
        checkpoint_hparams.pop("lr", None)  # Remove training params for inference
    return checkpoint_hparams
```

I also included 2 comprehensive tests: test_adapt_checkpoint_hparams_hook and test_adapt_checkpoint_hparams_hook_empty_dict.
- Split method signature across multiple lines to stay within the 120 char limit
- Improves code readability in the documentation example
mauvilsa left a comment
It is looking good. But the two tests fail. You will need to implement a new Model class for these tests.
```python
    assert cli.model.layer.out_features == 4


def test_adapt_checkpoint_hparams_hook(cleandir):
```
Suggested change:

```diff
-def test_adapt_checkpoint_hparams_hook(cleandir):
+def test_adapt_checkpoint_hparams_hook_pop_keys(cleandir):
```
```python
    def add_arguments_to_parser(self, parser):
        parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
```
Suggested change:

```diff
-    def add_arguments_to_parser(self, parser):
-        parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
```
Linking of arguments is not relevant to test this hook. Better to not have it to avoid distraction.
```python
    def add_arguments_to_parser(self, parser):
        parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
```
Suggested change:

```diff
-    def add_arguments_to_parser(self, parser):
-        parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
```
Linking of arguments is not relevant to test this hook. Better to not have it to avoid distraction.
```python
    # First, create a checkpoint
    cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = AdaptHparamsEmptyCLI(BoringCkptPathModel)
```
The test fails because `BoringCkptPathModel` has a module `torch.nn.Linear(32, out_dim)`. If `out_dim` is changed, there is a tensor size mismatch.
Instead of using `BoringCkptPathModel`, implement a new class for these two tests that just sets an attribute which can be asserted after instantiation.
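For illustration, a minimal model along these lines might look like the sketch below (the class name and the extra `note` hyperparameter are made up for this example, not part of the PR):

```python
from lightning.pytorch.demos.boring_classes import BoringModel


class HparamsRecordingModel(BoringModel):
    """Stores an extra hyperparameter that is never used to size a layer,
    so the hook's effect can be asserted without any tensor shape mismatch."""

    def __init__(self, note: str = "default"):
        super().__init__()
        self.save_hyperparameters()
        self.note = note
```

A test could then fit once with `--model.note=from_training`, reload the resulting checkpoint through a CLI whose `adapt_checkpoint_hparams` drops or rewrites `note`, and simply assert on `cli.model.note`.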
What does this PR do?
Fixes #21255
This PR adds a public `adapt_checkpoint_hparams()` hook to `LightningCLI` that allows users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This solves the problem of loading checkpoints across different module classes (e.g., from `TrainingModule` to `InferenceModule`).

Problem
When using `LightningCLI` with checkpoints, hyperparameters saved during training are automatically loaded and applied when running other subcommands (test, predict, etc.). This is convenient when using the same module class, but fails when using a different class with incompatible `__init__` parameters.

Example scenario:
Running `cli predict --ckpt_path checkpoint.ckpt` with `InferenceModule` fails because the CLI tries to pass `lr=1e-3` from the checkpoint to `InferenceModule.__init__()`.
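As a rough illustration of this scenario (the module definitions below are hypothetical, reduced to the minimum needed to show the mismatch):

```python
from lightning.pytorch import LightningModule


class TrainingModule(LightningModule):
    def __init__(self, lr: float = 1e-3, weight_decay: float = 0.0):
        super().__init__()
        self.save_hyperparameters()  # "lr" and "weight_decay" are stored in the checkpoint


class InferenceModule(LightningModule):
    def __init__(self):
        # No lr/weight_decay parameters, so passing them from the checkpoint raises a TypeError
        super().__init__()
```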
Solution

Added an `adapt_checkpoint_hparams()` public method that users can override to customize loaded hyperparameters; see the example use cases below.
Implementation Details

- Added the `adapt_checkpoint_hparams()` public method in `LightningCLI`
- Updated `_parse_ckpt_path()` to call the hook after loading but before applying hyperparameters

Why This Approach?
As discussed in #21255, this is superior to alternatives:
- Disabling checkpoint loading entirely (loses valuable hyperparameter info)
- Adding CLI arguments (deviates from the Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)
- Argument linking (e.g., for `hidden_dim`)

Testing
The implementation:
- Documents `_class_path` modification when needed

Example Use Cases
Remove training-only parameters:
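A minimal sketch, assuming the `subcommand`-aware signature adopted in this PR and illustrative hyperparameter names:

```python
from typing import Any

from lightning.pytorch.cli import LightningCLI


class InferenceCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        # Strip optimizer settings for everything except training runs.
        if subcommand != "fit":
            checkpoint_hparams.pop("lr", None)
            checkpoint_hparams.pop("weight_decay", None)
        return checkpoint_hparams
```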
Change module class in subclass mode:
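A sketch for subclass mode; the target class path `my_project.InferenceModule` is hypothetical:

```python
from typing import Any

from lightning.pytorch.cli import LightningCLI


class SwapClassCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        # In subclass mode the checkpoint stores the original class under "_class_path";
        # point it at the class that should be instantiated instead.
        if subcommand != "fit" and "_class_path" in checkpoint_hparams:
            checkpoint_hparams["_class_path"] = "my_project.InferenceModule"
        return checkpoint_hparams
```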
Disable all checkpoint hyperparameters:
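A sketch; returning an empty dict skips applying any checkpoint hyperparameters, so only CLI and config values are used:

```python
from typing import Any

from lightning.pytorch.cli import LightningCLI


class NoCkptHparamsCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        # Ignore everything stored in the checkpoint.
        return {}
```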
Does your PR introduce any breaking changes?
No, this is a purely additive change. The default implementation returns hyperparameters unchanged, preserving existing behavior.
Before submitting
PR review
cc: @mauvilsa @ziw-liu
📚 Documentation preview 📚: https://pytorch-lightning--21408.org.readthedocs.build/en/21408/