diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 26af335f7be93..9f970f0c22a42 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -138,7 +138,10 @@ def add_lightning_class_args( if issubclass(lightning_class, Callback): self.callback_keys.append(nested_key) if subclass_mode: + print("Subclass mode") + print(nested_key) return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required) + print(nested_key) return self.add_class_arguments( lightning_class, nested_key, diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 56b58d4d157a1..0a3b885211f78 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -597,6 +597,53 @@ def add_arguments_to_parser(self, parser): assert cli.model.num_classes == 5 +# Notes: +# - if variable is class attribute, it will only be present after instantiation -> apply_on="instantiate" +# - if you link argumetns and then try to pass them additionally, it will raise an error +# - if you initilize the model from the cmd instead of passing it as an instance to CLI init_args needs to be +# added additionally +# - if you initilize the model from the cmd instead of passing it as an instance to CLI apply on instantiate needs to +# be always true for linked arguments + + +def test_lightning_cli_link_arguments_init(): + # Will not work without init_args ("--data.init_args.batch_size=12") + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.link_arguments("data.batch_size", "model.batch_size") + parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate") + + cli_args = [ + "--data.batch_size=12", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False) + + assert cli.model.batch_size == 12 + assert cli.datamodule.batch_size == 12 + assert cli.model.num_classes == 5 + assert cli.datamodule.num_classes == 5 + + # Will work without init_args ("--data.batch_size=12") + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.link_arguments("data.batch_size", "model.init_args.batch_size", apply_on="instantiate") + parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate") + + cli_args = [ + "--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses", + "--model=tests_pytorch.test_cli.BoringModelRequiredClasses", + "--data.batch_size=12", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(run=False) + + assert cli.datamodule.batch_size == 12 + assert cli.model.batch_size == 12 + + class EarlyExitTestModel(BoringModel): def on_fit_start(self): raise MisconfigurationException("Error on fit start")