Skip to content

Commit 998ea3e

Browse files
committed
refactor(cli): Add subcommand parameter to adapt_checkpoint_hparams hook and add tests
- Update the `adapt_checkpoint_hparams` signature to include a `subcommand` parameter, allowing context-aware customization of checkpoint hyperparameters
- Change type annotations to use lowercase `dict` (Python 3.9+ style)
- Update the docstring with documentation for the `subcommand` parameter
- Add an example showing conditional logic based on the subcommand
- Add comprehensive unit tests:
  - `test_adapt_checkpoint_hparams_hook`: tests that the hook is called and its modifications are applied
  - `test_adapt_checkpoint_hparams_hook_empty_dict`: tests disabling checkpoint hparams loading
  - Tests cover both regular and subclass modes
1 parent ad1a028 commit 998ea3e

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

src/lightning/pytorch/cli.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
560560
else:
561561
self.config = parser.parse_args(args)
562562

563-
def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
563+
def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
564564
"""Adapt checkpoint hyperparameters before instantiating the model class.
565565
566566
This method allows for customization of hyperparameters loaded from a checkpoint when
@@ -569,6 +569,8 @@ def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[s
569569
``__init__`` parameters, you can remove or modify incompatible hyperparameters.
570570
571571
Args:
572+
subcommand: The subcommand being executed (e.g., 'fit', 'validate', 'test', 'predict').
573+
This allows you to apply different hyperparameter adaptations depending on the context.
572574
checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.
573575
574576
Returns:
@@ -577,10 +579,11 @@ def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[s
577579
Example::
578580
579581
class MyCLI(LightningCLI):
580-
def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
581-
# Remove training-specific hyperparameters not needed for inference
582-
checkpoint_hparams.pop("lr", None)
583-
checkpoint_hparams.pop("weight_decay", None)
582+
def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
583+
# Only remove training-specific hyperparameters for non-fit subcommands
584+
if subcommand != "fit":
585+
checkpoint_hparams.pop("lr", None)
586+
checkpoint_hparams.pop("weight_decay", None)
584587
return checkpoint_hparams
585588
586589
Note:
@@ -603,7 +606,7 @@ def _parse_ckpt_path(self) -> None:
603606
return
604607

605608
# Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
606-
hparams = self.adapt_checkpoint_hparams(hparams)
609+
hparams = self.adapt_checkpoint_hparams(self.config.subcommand, hparams)
607610
if not hparams:
608611
return
609612

tests/tests_pytorch/test_cli.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,66 @@ def add_arguments_to_parser(self, parser):
562562
assert cli.model.layer.out_features == 4
563563

564564

565+
def test_adapt_checkpoint_hparams_hook(cleandir):
566+
"""Test that the adapt_checkpoint_hparams hook is called and modifications are applied."""
567+
class AdaptHparamsCLI(LightningCLI):
568+
def add_arguments_to_parser(self, parser):
569+
parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
570+
571+
def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams):
572+
"""Remove out_dim and hidden_dim for non-fit subcommands."""
573+
if subcommand != "fit":
574+
checkpoint_hparams.pop("out_dim", None)
575+
checkpoint_hparams.pop("hidden_dim", None)
576+
return checkpoint_hparams
577+
578+
# First, create a checkpoint by running fit
579+
cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
580+
with mock.patch("sys.argv", ["any.py"] + cli_args):
581+
cli = AdaptHparamsCLI(BoringCkptPathModel)
582+
583+
assert cli.config.fit.model.out_dim == 3
584+
assert cli.config.fit.model.hidden_dim == 6
585+
586+
checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
587+
588+
# Test that predict uses adapted hparams (without out_dim and hidden_dim)
589+
cli_args = ["predict", f"--ckpt_path={checkpoint_path}", "--model.out_dim=5", "--model.hidden_dim=10"]
590+
with mock.patch("sys.argv", ["any.py"] + cli_args):
591+
cli = AdaptHparamsCLI(BoringCkptPathModel)
592+
593+
# Since we removed out_dim and hidden_dim for predict, the CLI values should be used
594+
assert cli.config.predict.model.out_dim == 5
595+
assert cli.config.predict.model.hidden_dim == 10
596+
597+
598+
def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir):
599+
"""Test that returning empty dict from adapt_checkpoint_hparams disables checkpoint hyperparameter loading."""
600+
class AdaptHparamsEmptyCLI(LightningCLI):
601+
def add_arguments_to_parser(self, parser):
602+
parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
603+
604+
def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams):
605+
"""Disable checkpoint hyperparameter loading."""
606+
return {}
607+
608+
# First, create a checkpoint
609+
cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
610+
with mock.patch("sys.argv", ["any.py"] + cli_args):
611+
cli = AdaptHparamsEmptyCLI(BoringCkptPathModel)
612+
613+
checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
614+
615+
# Test that predict uses default values when hook returns empty dict
616+
cli_args = ["predict", f"--ckpt_path={checkpoint_path}"]
617+
with mock.patch("sys.argv", ["any.py"] + cli_args):
618+
cli = AdaptHparamsEmptyCLI(BoringCkptPathModel)
619+
620+
# Model should use default values (out_dim=8, hidden_dim=16)
621+
assert cli.config_init.predict.model.out_dim == 8
622+
assert cli.config_init.predict.model.hidden_dim == 16
623+
624+
565625
def test_lightning_cli_submodules(cleandir):
566626
class MainModule(BoringModel):
567627
def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):

0 commit comments

Comments (0)