# Commentary only: `git apply` and `patch` skip text that precedes the first
# file header, so these lines are not part of the change set.
#
# Summary of this patch (three files):
#
# * src/lightning/pytorch/CHANGELOG.md
#     Adds a changelog entry for PR #21246. The misspelling "iwhen" in the
#     neighbouring entry sits on a context line and is kept verbatim here;
#     context lines have to match the target file for the hunk to apply.
#
# * src/lightning/pytorch/cli.py  (`_parse_ckpt_path`)
#     When the hyperparameters contain a `_class_path` key, that key is popped
#     and the remaining entries are nested under `dict_kwargs` next to
#     `class_path`. Only afterwards is the result wrapped as
#     `{subcommand: {"model": ...}}` and handed to `self.parser.parse_object`.
#     The pop runs before `"dict_kwargs": hparams` is evaluated, so the nested
#     dict no longer contains `_class_path`.
#
# * tests/tests_pytorch/test_cli.py
#     - `BoringCkptPathModel.__init__` now also stores `self.hidden_dim`.
#     - New `BoringCkptPathSubclass`, which adds an `extra` flag and forwards
#       the remaining keyword arguments to the parent constructor.
#     - New test `test_lightning_cli_ckpt_path_argument_hparams_subclass_mode`:
#       runs `fit` with the subclass in subclass mode, checks the parsed config
#       and the written `hparams.yaml`, then runs `predict` with
#       `--model=BoringCkptPathModel` plus the produced checkpoint and asserts
#       that the resulting model is a `BoringCkptPathSubclass` with
#       `hidden_dim == 8`, `extra is True` and `layer.out_features == 4`.
#
# NOTE(review): the line breaks and Python indentation below were rebuilt from
# a whitespace-collapsed copy. The per-hunk line counts agree with the `@@`
# headers, but the exact indentation should be confirmed against the upstream
# commit before applying.
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 863a3a4a7e939..1bba5e4ca0da7 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed preventing recursive symlink creation iwhen `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))
 
+- Fixed `LightningCLI` loading of hyperparameters from `ckpt_path` failing for subclass model mode ([#21246](https://github.com/Lightning-AI/pytorch-lightning/pull/21246))
+
+
 ---
 
 ## [2.5.5] - 2025-09-05
diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 91247127f6c87..ed065fdd12797 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -564,6 +564,11 @@ def _parse_ckpt_path(self) -> None:
             hparams.pop("_instantiator", None)
             if not hparams:
                 return
+            if "_class_path" in hparams:
+                hparams = {
+                    "class_path": hparams.pop("_class_path"),
+                    "dict_kwargs": hparams,
+                }
             hparams = {self.config.subcommand: {"model": hparams}}
             try:
                 self.config = self.parser.parse_object(hparams, self.config)
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index 95124854a8fd5..e79f1b78e02da 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -491,6 +491,7 @@ class BoringCkptPathModel(BoringModel):
     def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None:
         super().__init__()
         self.save_hyperparameters()
+        self.hidden_dim = hidden_dim
         self.layer = torch.nn.Linear(32, out_dim)
 
 
@@ -526,6 +527,41 @@ def add_arguments_to_parser(self, parser):
     assert "Parsing of ckpt_path hyperparameters failed" in err.getvalue()
 
 
+class BoringCkptPathSubclass(BoringCkptPathModel):
+    def __init__(self, extra: bool = True, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.extra = extra
+
+
+def test_lightning_cli_ckpt_path_argument_hparams_subclass_mode(cleandir):
+    class CkptPathCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments("model.init_args.out_dim", "model.init_args.hidden_dim", compute_fn=lambda x: x * 2)
+
+    cli_args = ["fit", "--model=BoringCkptPathSubclass", "--model.out_dim=4", "--trainer.max_epochs=1"]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = CkptPathCLI(BoringCkptPathModel, subclass_mode_model=True)
+
+    assert cli.config.fit.model.class_path.endswith(".BoringCkptPathSubclass")
+    assert cli.config.fit.model.init_args == Namespace(out_dim=4, hidden_dim=8, extra=True)
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+    assert hparams["out_dim"] == 4
+    assert hparams["hidden_dim"] == 8
+    assert hparams["extra"] is True
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
+    cli_args = ["predict", "--model=BoringCkptPathModel", f"--ckpt_path={checkpoint_path}"]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = CkptPathCLI(BoringCkptPathModel, subclass_mode_model=True)
+
+    assert isinstance(cli.model, BoringCkptPathSubclass)
+    assert cli.model.hidden_dim == 8
+    assert cli.model.extra is True
+    assert cli.model.layer.out_features == 4
+
+
 def test_lightning_cli_submodules(cleandir):
     class MainModule(BoringModel):
         def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):