41 changes: 41 additions & 0 deletions src/lightning/pytorch/cli.py
@@ -560,6 +560,41 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
else:
self.config = parser.parse_args(args)

def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
"""Adapt checkpoint hyperparameters before instantiating the model class.

This method allows for customization of hyperparameters loaded from a checkpoint when
using a different model class than the one used for training. For example, when loading
a checkpoint from a TrainingModule to use with an InferenceModule that has different
``__init__`` parameters, you can remove or modify incompatible hyperparameters.

Args:
subcommand: The subcommand being executed (e.g., 'fit', 'validate', 'test', 'predict').
This allows you to apply different hyperparameter adaptations depending on the context.
checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

Returns:
Dictionary of adapted hyperparameters to be used for model instantiation.

Example::

class MyCLI(LightningCLI):
def adapt_checkpoint_hparams(
self, subcommand: str, checkpoint_hparams: dict[str, Any]
) -> dict[str, Any]:
# Only remove training-specific hyperparameters for non-fit subcommands
if subcommand != "fit":
checkpoint_hparams.pop("lr", None)
checkpoint_hparams.pop("weight_decay", None)
return checkpoint_hparams

Note:
If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
hyperparameters, you may need to modify it as well to point to your new module class.

"""
return checkpoint_hparams

def _parse_ckpt_path(self) -> None:
"""If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
if not self.config.get("subcommand"):
@@ -571,6 +606,12 @@ def _parse_ckpt_path(self) -> None:
hparams.pop("_instantiator", None)
if not hparams:
return

# Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
hparams = self.adapt_checkpoint_hparams(self.config.subcommand, hparams)
if not hparams:
return

if "_class_path" in hparams:
hparams = {
"class_path": hparams.pop("_class_path"),
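To illustrate the Note in the docstring above, here is a minimal sketch of how the hook could be used to load a training checkpoint into a different module class when subclass module mode is enabled. The dropped argument names and the my_project.modules.InferenceModule class path are hypothetical examples, not part of this PR:

from typing import Any

from lightning.pytorch.cli import LightningCLI


class InferenceCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        if subcommand == "predict":
            # Drop training-only __init__ arguments that the inference module does not accept
            # (hypothetical names).
            for key in ("lr", "weight_decay", "scheduler"):
                checkpoint_hparams.pop(key, None)
            # In subclass module mode the checkpoint may record the training class path;
            # repoint it so the inference module is instantiated instead (hypothetical path).
            if "_class_path" in checkpoint_hparams:
                checkpoint_hparams["_class_path"] = "my_project.modules.InferenceModule"
        return checkpoint_hparams
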
62 changes: 62 additions & 0 deletions tests/tests_pytorch/test_cli.py
@@ -562,6 +562,68 @@ def add_arguments_to_parser(self, parser):
assert cli.model.layer.out_features == 4


def test_adapt_checkpoint_hparams_hook(cleandir):
Review comment (Contributor), suggested change:
-def test_adapt_checkpoint_hparams_hook(cleandir):
+def test_adapt_checkpoint_hparams_hook_pop_keys(cleandir):

"""Test that the adapt_checkpoint_hparams hook is called and modifications are applied."""

class AdaptHparamsCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)

Review comment (Contributor) on lines +569 to +571, suggested change (remove these lines):
-def add_arguments_to_parser(self, parser):
-    parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)

Linking of arguments is not relevant to testing this hook; better to leave it out to avoid distraction.

def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams):
"""Remove out_dim and hidden_dim for non-fit subcommands."""
if subcommand != "fit":
checkpoint_hparams.pop("out_dim", None)
checkpoint_hparams.pop("hidden_dim", None)
return checkpoint_hparams

# First, create a checkpoint by running fit
cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsCLI(BoringCkptPathModel)

assert cli.config.fit.model.out_dim == 3
assert cli.config.fit.model.hidden_dim == 6

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))

# Test that predict uses adapted hparams (without out_dim and hidden_dim)
cli_args = ["predict", f"--ckpt_path={checkpoint_path}", "--model.out_dim=5", "--model.hidden_dim=10"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsCLI(BoringCkptPathModel)

# Since we removed out_dim and hidden_dim for predict, the CLI values should be used
assert cli.config.predict.model.out_dim == 5
assert cli.config.predict.model.hidden_dim == 10


def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir):
"""Test that returning empty dict from adapt_checkpoint_hparams disables checkpoint hyperparameter loading."""

class AdaptHparamsEmptyCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)

Review comment (Contributor) on lines +603 to +605, suggested change (remove these lines):
-def add_arguments_to_parser(self, parser):
-    parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)

Linking of arguments is not relevant to testing this hook; better to leave it out to avoid distraction.

def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams):
"""Disable checkpoint hyperparameter loading."""
return {}

# First, create a checkpoint
cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsEmptyCLI(BoringCkptPathModel)
Review comment (Contributor):
The test fails because BoringCkptPathModel has a module torch.nn.Linear(32, out_dim); if out_dim is changed, there is a tensor size mismatch when loading the checkpoint.

Instead of using BoringCkptPathModel, implement a new class for these two tests that just sets an attribute that can be asserted after instantiation.
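For reference, a minimal sketch of the kind of test-only module this comment suggests, assuming constructor defaults matching the asserts in the test below; the HparamsOnlyModel name is hypothetical:

class HparamsOnlyModel(BoringModel):
    """Stores its constructor arguments as plain attributes, with no shape-dependent layers."""

    def __init__(self, out_dim: int = 8, hidden_dim: int = 16):
        # BoringModel is the test model already used throughout tests_pytorch/test_cli.py
        super().__init__()
        self.save_hyperparameters()
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
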


checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))

# Test that predict uses default values when hook returns empty dict
cli_args = ["predict", f"--ckpt_path={checkpoint_path}"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsEmptyCLI(BoringCkptPathModel)

# Model should use default values (out_dim=8, hidden_dim=16)
assert cli.config_init.predict.model.out_dim == 8
assert cli.config_init.predict.model.hidden_dim == 16


def test_lightning_cli_submodules(cleandir):
class MainModule(BoringModel):
def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):