Skip to content

Commit 735b62f

Browse files
Implemented test for cli
It shows the inconsitency that is required to cli arguments depending on how you set up the model and data module
1 parent 84053be commit 735b62f

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

tests/tests_pytorch/test_cli.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -597,42 +597,47 @@ def add_arguments_to_parser(self, parser):
597597
assert cli.model.num_classes == 5
598598

599599

600+
# Notes:
601+
# - if variable is class attribute, it will only be present after instantiation -> apply_on="instantiate"
602+
# If you link argumetns and then try to pass them additionally, it will raise an error
603+
604+
600605
def test_lightning_cli_link_arguments_init():
601606
# Will not work without init_args ("--data.init_args.batch_size=12")
602607
class MyLightningCLI(LightningCLI):
603608
def add_arguments_to_parser(self, parser):
604-
parser.link_arguments("data.batch_size", "model.init_args.batch_size")
609+
parser.link_arguments("data.batch_size", "model.batch_size")
610+
parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")
605611

606612
cli_args = [
607-
"--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses",
608-
"--model=tests_pytorch.test_cli.BoringModelRequiredClasses",
609-
"--data.init_args.batch_size=12",
610-
"--model.init_args.num_classes=5",
613+
"--data.batch_size=12",
611614
]
612615

613616
with mock.patch("sys.argv", ["any.py"] + cli_args):
614-
cli = MyLightningCLI(run=False)
617+
cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)
615618

619+
assert cli.model.batch_size == 12
616620
assert cli.datamodule.batch_size == 12
621+
assert cli.model.num_classes == 5
622+
assert cli.datamodule.num_classes == 5
617623

618624
# Will work without init_args ("--data.batch_size=12")
619625
class MyLightningCLI(LightningCLI):
620626
def add_arguments_to_parser(self, parser):
621-
pass
627+
parser.link_arguments("data.batch_size", "model.batch_size")
628+
parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")
622629

623630
cli_args = [
624631
"--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses",
625632
"--model=tests_pytorch.test_cli.BoringModelRequiredClasses",
626633
"--data.batch_size=12",
627-
"--model.num_classes=12",
628634
]
629635

630636
with mock.patch("sys.argv", ["any.py"] + cli_args):
631637
cli = MyLightningCLI(run=False)
632638

633-
print(cli.config)
634-
635639
assert cli.datamodule.batch_size == 12
640+
assert cli.model.batch_size == 12
636641

637642

638643
class EarlyExitTestModel(BoringModel):

0 commit comments

Comments
 (0)