@@ -597,42 +597,47 @@ def add_arguments_to_parser(self, parser):
597597 assert cli .model .num_classes == 5
598598
599599
600+ # Notes:
601+ # - if variable is class attribute, it will only be present after instantiation -> apply_on="instantiate"
602+ # If you link argumetns and then try to pass them additionally, it will raise an error
603+
604+
600605def test_lightning_cli_link_arguments_init ():
601606 # Will not work without init_args ("--data.init_args.batch_size=12")
602607 class MyLightningCLI (LightningCLI ):
603608 def add_arguments_to_parser (self , parser ):
604- parser .link_arguments ("data.batch_size" , "model.init_args.batch_size" )
609+ parser .link_arguments ("data.batch_size" , "model.batch_size" )
610+ parser .link_arguments ("data.num_classes" , "model.num_classes" , apply_on = "instantiate" )
605611
606612 cli_args = [
607- "--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses" ,
608- "--model=tests_pytorch.test_cli.BoringModelRequiredClasses" ,
609- "--data.init_args.batch_size=12" ,
610- "--model.init_args.num_classes=5" ,
613+ "--data.batch_size=12" ,
611614 ]
612615
613616 with mock .patch ("sys.argv" , ["any.py" ] + cli_args ):
614- cli = MyLightningCLI (run = False )
617+ cli = MyLightningCLI (BoringModelRequiredClasses , BoringDataModuleBatchSizeAndClasses , run = False )
615618
619+ assert cli .model .batch_size == 12
616620 assert cli .datamodule .batch_size == 12
621+ assert cli .model .num_classes == 5
622+ assert cli .datamodule .num_classes == 5
617623
618624 # Will work without init_args ("--data.batch_size=12")
619625 class MyLightningCLI (LightningCLI ):
620626 def add_arguments_to_parser (self , parser ):
621- pass
627+ parser .link_arguments ("data.batch_size" , "model.batch_size" )
628+ parser .link_arguments ("data.num_classes" , "model.num_classes" , apply_on = "instantiate" )
622629
623630 cli_args = [
624631 "--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses" ,
625632 "--model=tests_pytorch.test_cli.BoringModelRequiredClasses" ,
626633 "--data.batch_size=12" ,
627- "--model.num_classes=12" ,
628634 ]
629635
630636 with mock .patch ("sys.argv" , ["any.py" ] + cli_args ):
631637 cli = MyLightningCLI (run = False )
632638
633- print (cli .config )
634-
635639 assert cli .datamodule .batch_size == 12
640+ assert cli .model .batch_size == 12
636641
637642
638643class EarlyExitTestModel (BoringModel ):
0 commit comments