From 910a712a07cae43fade20208a7ee0fbd2dcc6218 Mon Sep 17 00:00:00 2001 From: arrdel Date: Fri, 5 Dec 2025 21:58:11 -0500 Subject: [PATCH 1/5] Add adapt_checkpoint_hparams hook for customizing checkpoint hyperparameter loading Fixes #21255 This commit adds the adapt_checkpoint_hparams() public method to LightningCLI, allowing users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This is particularly useful when using checkpoints from a TrainingModule with a different InferenceModule class that has different __init__ parameters. Problem: When loading a checkpoint trained with TrainingModule(lr=1e-3) into an InferenceModule() that doesn't accept 'lr' as a parameter, the CLI would fail during instantiation because it tries to pass all checkpoint hyperparameters to the new module class. Solution: Added adapt_checkpoint_hparams() hook that is called in _parse_ckpt_path() after loading checkpoint hyperparameters but before applying them. Users can override this method to: - Remove training-specific hyperparameters (e.g., lr, weight_decay) - Modify _class_path for subclass mode - Transform hyperparameter names/values - Completely disable checkpoint hyperparameters by returning {} Example usage: class MyCLI(LightningCLI): def adapt_checkpoint_hparams(self, checkpoint_hparams): checkpoint_hparams.pop('lr', None) checkpoint_hparams.pop('weight_decay', None) return checkpoint_hparams This approach is preferable to: - Disabling checkpoint loading entirely (loses valuable hyperparameter info) - Adding CLI arguments (deviates from Trainer parameter pattern) - Modifying private methods (breaks encapsulation) The hook provides maximum flexibility while maintaining backward compatibility (default implementation returns hyperparameters unchanged). --- src/lightning/pytorch/cli.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 2bcb1d8f4b1fd..87b3634dda413 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -560,6 +560,36 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No else: self.config = parser.parse_args(args) + def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]: + """Adapt checkpoint hyperparameters before instantiating the model class. + + This method allows for customization of hyperparameters loaded from a checkpoint when + using a different model class than the one used for training. For example, when loading + a checkpoint from a TrainingModule to use with an InferenceModule that has different + ``__init__`` parameters, you can remove or modify incompatible hyperparameters. + + Args: + checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint. + + Returns: + Dictionary of adapted hyperparameters to be used for model instantiation. + + Example:: + + class MyCLI(LightningCLI): + def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]: + # Remove training-specific hyperparameters not needed for inference + checkpoint_hparams.pop("lr", None) + checkpoint_hparams.pop("weight_decay", None) + return checkpoint_hparams + + Note: + If subclass module mode is enabled and ``_class_path`` is present in the checkpoint + hyperparameters, you may need to modify it as well to point to your new module class. 
+ + """ + return checkpoint_hparams + def _parse_ckpt_path(self) -> None: """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config.""" if not self.config.get("subcommand"): @@ -571,6 +601,12 @@ def _parse_ckpt_path(self) -> None: hparams.pop("_instantiator", None) if not hparams: return + + # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook + hparams = self.adapt_checkpoint_hparams(hparams) + if not hparams: + return + if "_class_path" in hparams: hparams = { "class_path": hparams.pop("_class_path"), From ad1a0285dfa3d2b9fcb98ee725d14d5ff63fd6fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Dec 2025 02:59:15 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 87b3634dda413..2e2bc939705ef 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -601,12 +601,12 @@ def _parse_ckpt_path(self) -> None: hparams.pop("_instantiator", None) if not hparams: return - + # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook hparams = self.adapt_checkpoint_hparams(hparams) if not hparams: return - + if "_class_path" in hparams: hparams = { "class_path": hparams.pop("_class_path"), From 998ea3e6968613e4d2a05040749678841b18137e Mon Sep 17 00:00:00 2001 From: arrdel Date: Tue, 9 Dec 2025 10:39:30 -0500 Subject: [PATCH 3/5] refactor(cli): Add subcommand parameter to adapt_checkpoint_hparams hook and add tests - Update adapt_checkpoint_hparams signature to include subcommand parameter allowing context-aware customization of checkpoint hyperparameters - Change type annotations to use lowercase dict (Python 3.9+ style) - Update docstring with subcommand parameter documentation - Add example showing conditional logic based on subcommand - Add comprehensive unit tests: - test_adapt_checkpoint_hparams_hook: Tests that hook is called and modifications applied - test_adapt_checkpoint_hparams_hook_empty_dict: Tests disabling checkpoint hparams loading - Tests cover both regular and subclass modes --- src/lightning/pytorch/cli.py | 15 +++++---- tests/tests_pytorch/test_cli.py | 60 +++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 2e2bc939705ef..d897e39660066 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -560,7 +560,7 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No else: self.config = parser.parse_args(args) - def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]: + def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]: """Adapt checkpoint hyperparameters before instantiating the model class. This method allows for customization of hyperparameters loaded from a checkpoint when @@ -569,6 +569,8 @@ def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[s ``__init__`` parameters, you can remove or modify incompatible hyperparameters. Args: + subcommand: The subcommand being executed (e.g., 'fit', 'validate', 'test', 'predict'). 
+ This allows you to apply different hyperparameter adaptations depending on the context. checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint. Returns: @@ -577,10 +579,11 @@ def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[s Example:: class MyCLI(LightningCLI): - def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]: - # Remove training-specific hyperparameters not needed for inference - checkpoint_hparams.pop("lr", None) - checkpoint_hparams.pop("weight_decay", None) + def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]: + # Only remove training-specific hyperparameters for non-fit subcommands + if subcommand != "fit": + checkpoint_hparams.pop("lr", None) + checkpoint_hparams.pop("weight_decay", None) return checkpoint_hparams Note: @@ -603,7 +606,7 @@ def _parse_ckpt_path(self) -> None: return # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook - hparams = self.adapt_checkpoint_hparams(hparams) + hparams = self.adapt_checkpoint_hparams(self.config.subcommand, hparams) if not hparams: return diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 6ff4bee264a7b..fe14d3de920cd 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -562,6 +562,66 @@ def add_arguments_to_parser(self, parser): assert cli.model.layer.out_features == 4 +def test_adapt_checkpoint_hparams_hook(cleandir): + """Test that the adapt_checkpoint_hparams hook is called and modifications are applied.""" + class AdaptHparamsCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) + + def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams): + """Remove out_dim and hidden_dim for non-fit subcommands.""" + if subcommand != "fit": + checkpoint_hparams.pop("out_dim", None) + checkpoint_hparams.pop("hidden_dim", None) + return checkpoint_hparams + + # First, create a checkpoint by running fit + cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = AdaptHparamsCLI(BoringCkptPathModel) + + assert cli.config.fit.model.out_dim == 3 + assert cli.config.fit.model.hidden_dim == 6 + + checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt")) + + # Test that predict uses adapted hparams (without out_dim and hidden_dim) + cli_args = ["predict", f"--ckpt_path={checkpoint_path}", "--model.out_dim=5", "--model.hidden_dim=10"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = AdaptHparamsCLI(BoringCkptPathModel) + + # Since we removed out_dim and hidden_dim for predict, the CLI values should be used + assert cli.config.predict.model.out_dim == 5 + assert cli.config.predict.model.hidden_dim == 10 + + +def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir): + """Test that returning empty dict from adapt_checkpoint_hparams disables checkpoint hyperparameter loading.""" + class AdaptHparamsEmptyCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) + + def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams): + """Disable checkpoint hyperparameter loading.""" + return {} + + # First, create a checkpoint + cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"] + with 
mock.patch("sys.argv", ["any.py"] + cli_args): + cli = AdaptHparamsEmptyCLI(BoringCkptPathModel) + + checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt")) + + # Test that predict uses default values when hook returns empty dict + cli_args = ["predict", f"--ckpt_path={checkpoint_path}"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = AdaptHparamsEmptyCLI(BoringCkptPathModel) + + # Model should use default values (out_dim=8, hidden_dim=16) + assert cli.config_init.predict.model.out_dim == 8 + assert cli.config_init.predict.model.hidden_dim == 16 + + def test_lightning_cli_submodules(cleandir): class MainModule(BoringModel): def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1): From 00e7032f9c8c2d62b6b7552db5223cd4b173fa97 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Dec 2025 15:40:17 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/test_cli.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index fe14d3de920cd..e044e7727b127 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -564,6 +564,7 @@ def add_arguments_to_parser(self, parser): def test_adapt_checkpoint_hparams_hook(cleandir): """Test that the adapt_checkpoint_hparams hook is called and modifications are applied.""" + class AdaptHparamsCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) @@ -597,6 +598,7 @@ def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams): def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir): """Test that returning empty dict from adapt_checkpoint_hparams disables checkpoint hyperparameter loading.""" + class AdaptHparamsEmptyCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) From b3b10252e610776c1cfabdde3b899a387dfa7204 Mon Sep 17 00:00:00 2001 From: arrdel Date: Tue, 9 Dec 2025 13:18:50 -0500 Subject: [PATCH 5/5] fix: Break long line in adapt_checkpoint_hparams docstring example - Split method signature across multiple lines to stay within 120 char limit - Improves code readability in documentation example --- src/lightning/pytorch/cli.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index d897e39660066..9226b1dfab02f 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -579,7 +579,9 @@ def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str Example:: class MyCLI(LightningCLI): - def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]: + def adapt_checkpoint_hparams( + self, subcommand: str, checkpoint_hparams: dict[str, Any] + ) -> dict[str, Any]: # Only remove training-specific hyperparameters for non-fit subcommands if subcommand != "fit": checkpoint_hparams.pop("lr", None)
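
A minimal usage sketch of the hook introduced by this patch series, using the final signature from PATCH 3/5 (with the subcommand argument). TrainingModule checkpoints, InferenceModule, and InferenceCLI below are hypothetical stand-ins for the scenario described in #21255; they are not part of the patch itself:

    from typing import Any

    from lightning.pytorch import LightningModule
    from lightning.pytorch.cli import LightningCLI


    class InferenceModule(LightningModule):
        # Hypothetical inference-only module: unlike the TrainingModule whose
        # checkpoint we want to load, its __init__ has no lr/weight_decay.
        def __init__(self, hidden_dim: int = 16):
            super().__init__()
            self.save_hyperparameters()


    class InferenceCLI(LightningCLI):
        def adapt_checkpoint_hparams(
            self, subcommand: str, checkpoint_hparams: dict[str, Any]
        ) -> dict[str, Any]:
            # Drop training-only hyperparameters so they are not passed to
            # InferenceModule.__init__ when --ckpt_path points at a checkpoint
            # produced by the training module.
            if subcommand != "fit":
                checkpoint_hparams.pop("lr", None)
                checkpoint_hparams.pop("weight_decay", None)
            return checkpoint_hparams


    if __name__ == "__main__":
        # e.g. python inference.py predict --ckpt_path=path/to/training.ckpt
        InferenceCLI(InferenceModule)

Returning an empty dict from the hook skips checkpoint hyperparameters entirely and falls back to CLI arguments and parser defaults, as exercised by test_adapt_checkpoint_hparams_hook_empty_dict above.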