Skip to content

Commit 84053be

Browse files
Added first test that shows mentioned deficites
1 parent 8ad3e29 commit 84053be

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

tests/tests_pytorch/test_cli.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,44 @@ def add_arguments_to_parser(self, parser):
597597
assert cli.model.num_classes == 5
598598

599599

600+
def test_lightning_cli_link_arguments_init():
601+
# Will not work without init_args ("--data.init_args.batch_size=12")
602+
class MyLightningCLI(LightningCLI):
603+
def add_arguments_to_parser(self, parser):
604+
parser.link_arguments("data.batch_size", "model.init_args.batch_size")
605+
606+
cli_args = [
607+
"--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses",
608+
"--model=tests_pytorch.test_cli.BoringModelRequiredClasses",
609+
"--data.init_args.batch_size=12",
610+
"--model.init_args.num_classes=5",
611+
]
612+
613+
with mock.patch("sys.argv", ["any.py"] + cli_args):
614+
cli = MyLightningCLI(run=False)
615+
616+
assert cli.datamodule.batch_size == 12
617+
618+
# Will work without init_args ("--data.batch_size=12")
619+
class MyLightningCLI(LightningCLI):
620+
def add_arguments_to_parser(self, parser):
621+
pass
622+
623+
cli_args = [
624+
"--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses",
625+
"--model=tests_pytorch.test_cli.BoringModelRequiredClasses",
626+
"--data.batch_size=12",
627+
"--model.num_classes=12",
628+
]
629+
630+
with mock.patch("sys.argv", ["any.py"] + cli_args):
631+
cli = MyLightningCLI(run=False)
632+
633+
print(cli.config)
634+
635+
assert cli.datamodule.batch_size == 12
636+
637+
600638
class EarlyExitTestModel(BoringModel):
601639
def on_fit_start(self):
602640
raise MisconfigurationException("Error on fit start")

0 commit comments

Comments
 (0)