-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Add adapt_checkpoint_hparams hook for customizing checkpoint hyperparameter loading #21408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
910a712
ad1a028
998ea3e
00e7032
b3b1025
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -562,6 +562,68 @@ def add_arguments_to_parser(self, parser): | |||||
| assert cli.model.layer.out_features == 4 | ||||||
|
|
||||||
|
|
||||||
| def test_adapt_checkpoint_hparams_hook(cleandir): | ||||||
| """Test that the adapt_checkpoint_hparams hook is called and modifications are applied.""" | ||||||
|
|
||||||
| class AdaptHparamsCLI(LightningCLI): | ||||||
| def add_arguments_to_parser(self, parser): | ||||||
| parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) | ||||||
|
|
||||||
|
Comment on lines
+569
to
+571
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Linking of arguments is not relevant to test this hook. Better to not have it to avoid distraction. |
||||||
| def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams): | ||||||
| """Remove out_dim and hidden_dim for non-fit subcommands.""" | ||||||
| if subcommand != "fit": | ||||||
| checkpoint_hparams.pop("out_dim", None) | ||||||
| checkpoint_hparams.pop("hidden_dim", None) | ||||||
| return checkpoint_hparams | ||||||
|
|
||||||
| # First, create a checkpoint by running fit | ||||||
| cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"] | ||||||
| with mock.patch("sys.argv", ["any.py"] + cli_args): | ||||||
| cli = AdaptHparamsCLI(BoringCkptPathModel) | ||||||
|
|
||||||
| assert cli.config.fit.model.out_dim == 3 | ||||||
| assert cli.config.fit.model.hidden_dim == 6 | ||||||
|
|
||||||
| checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt")) | ||||||
|
|
||||||
| # Test that predict uses adapted hparams (without out_dim and hidden_dim) | ||||||
| cli_args = ["predict", f"--ckpt_path={checkpoint_path}", "--model.out_dim=5", "--model.hidden_dim=10"] | ||||||
| with mock.patch("sys.argv", ["any.py"] + cli_args): | ||||||
| cli = AdaptHparamsCLI(BoringCkptPathModel) | ||||||
|
|
||||||
| # Since we removed out_dim and hidden_dim for predict, the CLI values should be used | ||||||
| assert cli.config.predict.model.out_dim == 5 | ||||||
| assert cli.config.predict.model.hidden_dim == 10 | ||||||
|
|
||||||
|
|
||||||
| def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir): | ||||||
| """Test that returning empty dict from adapt_checkpoint_hparams disables checkpoint hyperparameter loading.""" | ||||||
|
|
||||||
| class AdaptHparamsEmptyCLI(LightningCLI): | ||||||
| def add_arguments_to_parser(self, parser): | ||||||
| parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) | ||||||
|
|
||||||
|
Comment on lines
+603
to
+605
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Linking of arguments is not relevant to test this hook. Better to not have it to avoid distraction. |
||||||
| def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams): | ||||||
| """Disable checkpoint hyperparameter loading.""" | ||||||
| return {} | ||||||
|
|
||||||
| # First, create a checkpoint | ||||||
| cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"] | ||||||
| with mock.patch("sys.argv", ["any.py"] + cli_args): | ||||||
| cli = AdaptHparamsEmptyCLI(BoringCkptPathModel) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test fails because of Instead of using |
||||||
|
|
||||||
| checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt")) | ||||||
|
|
||||||
| # Test that predict uses default values when hook returns empty dict | ||||||
| cli_args = ["predict", f"--ckpt_path={checkpoint_path}"] | ||||||
| with mock.patch("sys.argv", ["any.py"] + cli_args): | ||||||
| cli = AdaptHparamsEmptyCLI(BoringCkptPathModel) | ||||||
|
|
||||||
| # Model should use default values (out_dim=8, hidden_dim=16) | ||||||
| assert cli.config_init.predict.model.out_dim == 8 | ||||||
| assert cli.config_init.predict.model.hidden_dim == 16 | ||||||
|
|
||||||
|
|
||||||
| def test_lightning_cli_submodules(cleandir): | ||||||
| class MainModule(BoringModel): | ||||||
| def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1): | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.