@@ -597,6 +597,44 @@ def add_arguments_to_parser(self, parser):
597597 assert cli .model .num_classes == 5
598598
599599
600+ def test_lightning_cli_link_arguments_init ():
601+ # Will not work without init_args ("--data.init_args.batch_size=12")
602+ class MyLightningCLI (LightningCLI ):
603+ def add_arguments_to_parser (self , parser ):
604+ parser .link_arguments ("data.batch_size" , "model.init_args.batch_size" )
605+
606+ cli_args = [
607+ "--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses" ,
608+ "--model=tests_pytorch.test_cli.BoringModelRequiredClasses" ,
609+ "--data.init_args.batch_size=12" ,
610+ "--model.init_args.num_classes=5" ,
611+ ]
612+
613+ with mock .patch ("sys.argv" , ["any.py" ] + cli_args ):
614+ cli = MyLightningCLI (run = False )
615+
616+ assert cli .datamodule .batch_size == 12
617+
618+ # Will work without init_args ("--data.batch_size=12")
619+ class MyLightningCLI (LightningCLI ):
620+ def add_arguments_to_parser (self , parser ):
621+ pass
622+
623+ cli_args = [
624+ "--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses" ,
625+ "--model=tests_pytorch.test_cli.BoringModelRequiredClasses" ,
626+ "--data.batch_size=12" ,
627+ "--model.num_classes=12" ,
628+ ]
629+
630+ with mock .patch ("sys.argv" , ["any.py" ] + cli_args ):
631+ cli = MyLightningCLI (run = False )
632+
633+ print (cli .config )
634+
635+ assert cli .datamodule .batch_size == 12
636+
637+
600638class EarlyExitTestModel (BoringModel ):
601639 def on_fit_start (self ):
602640 raise MisconfigurationException ("Error on fit start" )
0 commit comments