Commit 910a712

Add adapt_checkpoint_hparams hook for customizing checkpoint hyperparameter loading
Fixes #21255

This commit adds the adapt_checkpoint_hparams() public method to LightningCLI, allowing users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This is particularly useful when using checkpoints from a TrainingModule with a different InferenceModule class that has different __init__ parameters.

Problem: when loading a checkpoint trained with TrainingModule(lr=1e-3) into an InferenceModule() that doesn't accept 'lr' as a parameter, the CLI would fail during instantiation because it tries to pass all checkpoint hyperparameters to the new module class.

Solution: add an adapt_checkpoint_hparams() hook that is called in _parse_ckpt_path() after loading checkpoint hyperparameters but before applying them. Users can override this method to:

- remove training-specific hyperparameters (e.g., lr, weight_decay)
- modify _class_path for subclass mode
- transform hyperparameter names/values
- disable checkpoint hyperparameters entirely by returning {}

Example usage:

    class MyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams):
            checkpoint_hparams.pop('lr', None)
            checkpoint_hparams.pop('weight_decay', None)
            return checkpoint_hparams

This approach is preferable to:

- disabling checkpoint loading entirely (loses valuable hyperparameter info)
- adding CLI arguments (deviates from the Trainer parameter pattern)
- modifying private methods (breaks encapsulation)

The hook provides maximum flexibility while maintaining backward compatibility: the default implementation returns the hyperparameters unchanged.
1 parent 79ffe50 commit 910a712
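For context on the failure mode described above, here is a minimal sketch of the two-module setup. The module names, layer shapes, and hyperparameters are illustrative, following the TrainingModule/InferenceModule example from the commit message:

    import torch
    from lightning.pytorch import LightningModule


    class TrainingModule(LightningModule):
        """Training-time module whose optimizer hyperparameters end up in the checkpoint."""

        def __init__(self, lr: float = 1e-3, weight_decay: float = 0.0):
            super().__init__()
            # save_hyperparameters() records lr/weight_decay in the checkpoint,
            # which is what the CLI later reads back when a ckpt_path is given.
            self.save_hyperparameters()
            self.layer = torch.nn.Linear(32, 2)


    class InferenceModule(LightningModule):
        """Inference-time module with no optimizer hyperparameters."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)


    # Re-instantiating with the checkpoint's stored hyperparameters fails,
    # because InferenceModule.__init__() does not accept them:
    #     InferenceModule(lr=1e-3, weight_decay=0.0)  # TypeError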

File tree

1 file changed: +36 −0 lines

src/lightning/pytorch/cli.py (36 additions, 0 deletions)
@@ -560,6 +560,36 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None
         else:
             self.config = parser.parse_args(args)
 
+    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
+        """Adapt checkpoint hyperparameters before instantiating the model class.
+
+        This method allows for customization of hyperparameters loaded from a checkpoint when
+        using a different model class than the one used for training. For example, when loading
+        a checkpoint from a TrainingModule to use with an InferenceModule that has different
+        ``__init__`` parameters, you can remove or modify incompatible hyperparameters.
+
+        Args:
+            checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.
+
+        Returns:
+            Dictionary of adapted hyperparameters to be used for model instantiation.
+
+        Example::
+
+            class MyCLI(LightningCLI):
+                def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
+                    # Remove training-specific hyperparameters not needed for inference
+                    checkpoint_hparams.pop("lr", None)
+                    checkpoint_hparams.pop("weight_decay", None)
+                    return checkpoint_hparams
+
+        Note:
+            If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
+            hyperparameters, you may need to modify it as well to point to your new module class.
+
+        """
+        return checkpoint_hparams
+
     def _parse_ckpt_path(self) -> None:
         """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
         if not self.config.get("subcommand"):
@@ -571,6 +601,12 @@ def _parse_ckpt_path(self) -> None:
         hparams.pop("_instantiator", None)
         if not hparams:
             return
+
+        # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
+        hparams = self.adapt_checkpoint_hparams(hparams)
+        if not hparams:
+            return
+
         if "_class_path" in hparams:
             hparams = {
                 "class_path": hparams.pop("_class_path"),
