Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed preventing recursive symlink creation iwhen `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))


- Fixed `LightningCLI` loading of hyperparameters from `ckpt_path` failing for subclass model mode ([#21246](https://github.com/Lightning-AI/pytorch-lightning/pull/21246))


---

## [2.5.5] - 2025-09-05
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,11 @@ def _parse_ckpt_path(self) -> None:
hparams.pop("_instantiator", None)
if not hparams:
return
if "_class_path" in hparams:
hparams = {
"class_path": hparams.pop("_class_path"),
"dict_kwargs": hparams,
}
hparams = {self.config.subcommand: {"model": hparams}}
try:
self.config = self.parser.parse_object(hparams, self.config)
Expand Down
36 changes: 36 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ class BoringCkptPathModel(BoringModel):
def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None:
super().__init__()
self.save_hyperparameters()
self.hidden_dim = hidden_dim
self.layer = torch.nn.Linear(32, out_dim)


Expand Down Expand Up @@ -526,6 +527,41 @@ def add_arguments_to_parser(self, parser):
assert "Parsing of ckpt_path hyperparameters failed" in err.getvalue()


class BoringCkptPathSubclass(BoringCkptPathModel):
def __init__(self, extra: bool = True, **kwargs) -> None:
super().__init__(**kwargs)
self.extra = extra


def test_lightning_cli_ckpt_path_argument_hparams_subclass_mode(cleandir):
class CkptPathCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("model.init_args.out_dim", "model.init_args.hidden_dim", compute_fn=lambda x: x * 2)

cli_args = ["fit", "--model=BoringCkptPathSubclass", "--model.out_dim=4", "--trainer.max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = CkptPathCLI(BoringCkptPathModel, subclass_mode_model=True)

assert cli.config.fit.model.class_path.endswith(".BoringCkptPathSubclass")
assert cli.config.fit.model.init_args == Namespace(out_dim=4, hidden_dim=8, extra=True)
hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())
assert hparams["out_dim"] == 4
assert hparams["hidden_dim"] == 8
assert hparams["extra"] is True

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
cli_args = ["predict", "--model=BoringCkptPathModel", f"--ckpt_path={checkpoint_path}"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = CkptPathCLI(BoringCkptPathModel, subclass_mode_model=True)

assert isinstance(cli.model, BoringCkptPathSubclass)
assert cli.model.hidden_dim == 8
assert cli.model.extra is True
assert cli.model.layer.out_features == 4


def test_lightning_cli_submodules(cleandir):
class MainModule(BoringModel):
def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
Expand Down
Loading