4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -42,6 +42,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `RichProgressBar` crashing during sanity checking when the val dataloader has length 0 ([#21108](https://github.com/Lightning-AI/pytorch-lightning/pull/21108))


- Fixed `LightningCLI` not using the hyperparameters stored in the `ckpt_path` checkpoint when instantiating classes ([#21116](https://github.com/Lightning-AI/pytorch-lightning/pull/21116))


---

## [2.5.3] - 2025-08-13
15 changes: 15 additions & 0 deletions src/lightning/pytorch/cli.py
@@ -16,6 +16,7 @@
import sys
from collections.abc import Iterable
from functools import partial, update_wrapper
from pathlib import Path
from types import MethodType
from typing import Any, Callable, Optional, TypeVar, Union

@@ -397,6 +398,7 @@ def __init__(
        main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs)
        self.setup_parser(run, main_kwargs, subparser_kwargs)
        self.parse_arguments(self.parser, args)
        self._parse_ckpt_path()

        self.subcommand = self.config["subcommand"] if run else None

@@ -551,6 +553,19 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
        else:
            self.config = parser.parse_args(args)

    def _parse_ckpt_path(self) -> None:
        """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
        if not self.config.get("subcommand"):
            return
        ckpt_path = self.config[self.config.subcommand].get("ckpt_path")
        if ckpt_path and Path(ckpt_path).is_file():
            ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu")
            hparams = ckpt.get("hyper_parameters", {})
            hparams.pop("_instantiator", None)
            if hparams:
                hparams = {self.config.subcommand: {"model": hparams}}
                self.config = self.parser.parse_object(hparams, self.config)

    def _dump_config(self) -> None:
        if hasattr(self, "config_dump"):
            return
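
To make the merge step concrete: `_parse_ckpt_path` nests the checkpoint's `hyper_parameters` under `<subcommand>.model` and re-parses them on top of the already-parsed CLI config. Below is a minimal, self-contained sketch using plain `jsonargparse` (the parser library `LightningCLI` builds on); the option names and values are illustrative, not taken from the diff.

```python
from jsonargparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--model.out_dim", type=int, default=2)
parser.add_argument("--model.hidden_dim", type=int, default=2)

config = parser.parse_args(["--model.out_dim=3"])  # what the CLI parsed initially
hparams = {"model": {"hidden_dim": 6}}             # e.g. loaded from ckpt["hyper_parameters"]
config = parser.parse_object(hparams, config)      # merge checkpoint values over the base config
assert config.model.out_dim == 3                   # kept from the base config
assert config.model.hidden_dim == 6                # taken from the checkpoint
```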
34 changes: 34 additions & 0 deletions tests/tests_pytorch/test_cli.py
@@ -487,6 +487,40 @@ def test_lightning_cli_print_config():
    assert outval["ckpt_path"] is None


class BoringCkptPathModel(BoringModel):
    def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, out_dim)


def test_lightning_cli_ckpt_path_argument_hparams(cleandir):
    class CkptPathCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)

    cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = CkptPathCLI(BoringCkptPathModel)

    assert cli.config.fit.model.out_dim == 3
    assert cli.config.fit.model.hidden_dim == 6
    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
    assert hparams_path.is_file()
    hparams = yaml.safe_load(hparams_path.read_text())
    assert hparams["out_dim"] == 3
    assert hparams["hidden_dim"] == 6

    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
    cli_args = ["predict", f"--ckpt_path={checkpoint_path}"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = CkptPathCLI(BoringCkptPathModel)

    assert cli.config.predict.model.out_dim == 3
    assert cli.config.predict.model.hidden_dim == 6
    assert cli.config_init.predict.model.layer.out_features == 3
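
For reference, a minimal, self-contained sketch (not part of the diff) of the checkpoint payload shape the `predict` run relies on: `save_hyperparameters()` in `BoringCkptPathModel.__init__` records the init arguments under the `hyper_parameters` key that `_parse_ckpt_path` later reads back. `demo.ckpt` is an illustrative filename, and real Lightning checkpoints carry additional keys (e.g. `epoch`, `optimizer_states`).

```python
import torch

# Stand-in for a checkpoint produced by the fit run above.
torch.save({"state_dict": {}, "hyper_parameters": {"out_dim": 3, "hidden_dim": 6}}, "demo.ckpt")

ckpt = torch.load("demo.ckpt", weights_only=True, map_location="cpu")
assert ckpt["hyper_parameters"] == {"out_dim": 3, "hidden_dim": 6}
```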


def test_lightning_cli_submodules(cleandir):
    class MainModule(BoringModel):
        def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):