diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 2bcb1d8f4b1fd..9226b1dfab02f 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -560,6 +560,41 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
         else:
             self.config = parser.parse_args(args)
 
+    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
+        """Adapt checkpoint hyperparameters before instantiating the model class.
+
+        This method allows for customization of hyperparameters loaded from a checkpoint when
+        using a different model class than the one used for training. For example, when loading
+        a checkpoint from a TrainingModule to use with an InferenceModule that has different
+        ``__init__`` parameters, you can remove or modify incompatible hyperparameters.
+
+        Args:
+            subcommand: The subcommand being executed (e.g., 'fit', 'validate', 'test', 'predict').
+                This allows you to apply different hyperparameter adaptations depending on the context.
+            checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.
+
+        Returns:
+            Dictionary of adapted hyperparameters to be used for model instantiation.
+
+        Example::
+
+            class MyCLI(LightningCLI):
+                def adapt_checkpoint_hparams(
+                    self, subcommand: str, checkpoint_hparams: dict[str, Any]
+                ) -> dict[str, Any]:
+                    # Only remove training-specific hyperparameters for non-fit subcommands
+                    if subcommand != "fit":
+                        checkpoint_hparams.pop("lr", None)
+                        checkpoint_hparams.pop("weight_decay", None)
+                    return checkpoint_hparams
+
+        Note:
+            If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
+            hyperparameters, you may need to modify it as well to point to your new module class.
+
+        """
+        return checkpoint_hparams
+
     def _parse_ckpt_path(self) -> None:
         """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
         if not self.config.get("subcommand"):
             return
@@ -571,6 +606,12 @@ def _parse_ckpt_path(self) -> None:
         hparams.pop("_instantiator", None)
         if not hparams:
             return
+
+        # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
+        hparams = self.adapt_checkpoint_hparams(self.config.subcommand, hparams)
+        if not hparams:
+            return
+
         if "_class_path" in hparams:
             hparams = {
                 "class_path": hparams.pop("_class_path"),
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index 6ff4bee264a7b..e044e7727b127 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -562,6 +562,68 @@ def add_arguments_to_parser(self, parser):
     assert cli.model.layer.out_features == 4
 
 
+def test_adapt_checkpoint_hparams_hook(cleandir):
+    """Test that the adapt_checkpoint_hparams hook is called and modifications are applied."""
+
+    class AdaptHparamsCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
+
+        def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams):
+            """Remove out_dim and hidden_dim for non-fit subcommands."""
+            if subcommand != "fit":
+                checkpoint_hparams.pop("out_dim", None)
+                checkpoint_hparams.pop("hidden_dim", None)
+            return checkpoint_hparams
+
+    # First, create a checkpoint by running fit
+    cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = AdaptHparamsCLI(BoringCkptPathModel)
+
+    assert cli.config.fit.model.out_dim == 3
+    assert cli.config.fit.model.hidden_dim == 6
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
+
+    # Test that predict uses adapted hparams (without out_dim and hidden_dim)
+    cli_args = ["predict", f"--ckpt_path={checkpoint_path}", "--model.out_dim=5", "--model.hidden_dim=10"]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = AdaptHparamsCLI(BoringCkptPathModel)
+
+    # Since we removed out_dim and hidden_dim for predict, the CLI values should be used
+    assert cli.config.predict.model.out_dim == 5
+    assert cli.config.predict.model.hidden_dim == 10
+
+
+def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir):
+    """Test that returning empty dict from adapt_checkpoint_hparams disables checkpoint hyperparameter loading."""
+
+    class AdaptHparamsEmptyCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
+
+        def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams):
+            """Disable checkpoint hyperparameter loading."""
+            return {}
+
+    # First, create a checkpoint
+    cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = AdaptHparamsEmptyCLI(BoringCkptPathModel)
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
+
+    # Test that predict uses default values when hook returns empty dict
+    cli_args = ["predict", f"--ckpt_path={checkpoint_path}"]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = AdaptHparamsEmptyCLI(BoringCkptPathModel)
+
+    # Model should use default values (out_dim=8, hidden_dim=16)
+    assert cli.config_init.predict.model.out_dim == 8
+    assert cli.config_init.predict.model.hidden_dim == 16
+
+
 def test_lightning_cli_submodules(cleandir):
     class MainModule(BoringModel):
         def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
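A hedged usage sketch (not part of the diff above): one way a downstream project might override the new adapt_checkpoint_hparams hook to also rewrite ``_class_path`` when running inference with a different module class, as suggested by the Note in the docstring. The ``InferenceCLI`` name, the ``lr`` hyperparameter, and the ``my_project.modules.InferenceModule`` class path are hypothetical placeholders, not part of this PR.

from typing import Any

from lightning.pytorch.cli import LightningCLI


class InferenceCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        # Outside of fit, drop training-only arguments and, if subclass module mode
        # stored a class path, repoint it at the (hypothetical) inference module.
        if subcommand != "fit":
            checkpoint_hparams.pop("lr", None)
            if "_class_path" in checkpoint_hparams:
                checkpoint_hparams["_class_path"] = "my_project.modules.InferenceModule"
        return checkpoint_hparams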