@@ -560,6 +560,7 @@ def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[
 class BoringModelRequiredClasses(BoringModel):
     def __init__(self, num_classes: int, batch_size: int = 8):
         super().__init__()
+        self.save_hyperparameters()
         self.num_classes = num_classes
         self.batch_size = batch_size
 
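A side note on the new `save_hyperparameters()` call and the `hparams.yaml` assertions in the next hunk: when the model is instantiated directly (outside the CLI), only the init arguments are captured; the `_instantiator` entry that those assertions pop appears to be added only when the CLI builds the model. A minimal sketch, assuming the class is importable from the test module:

from tests_pytorch.test_cli import BoringModelRequiredClasses

model = BoringModelRequiredClasses(num_classes=5, batch_size=12)
# save_hyperparameters() records the init args on model.hparams; loggers
# write them to hparams.yaml during fit().
assert dict(model.hparams) == {"num_classes": 5, "batch_size": 12}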
@@ -577,29 +578,97 @@ def add_arguments_to_parser(self, parser):
             parser.link_arguments("data.batch_size", "model.batch_size")
             parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")
 
-    cli_args = ["--data.batch_size=12"]
+    cli_args = ["--data.batch_size=12", "--trainer.max_epochs=1"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)
 
     assert cli.model.batch_size == 12
     assert cli.model.num_classes == 5
 
-    class MyLightningCLI(LightningCLI):
+    cli.trainer.fit(cli.model)
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    hparams.pop("_instantiator")
+    assert hparams == {"batch_size": 12, "num_classes": 5}
+
+    class MyLightningCLI2(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.link_arguments("data.batch_size", "model.init_args.batch_size")
             parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")
 
-    cli_args[-1] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
+    cli_args[0] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
-        cli = MyLightningCLI(
+        cli = MyLightningCLI2(
             BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
         )
 
     assert cli.model.batch_size == 8
     assert cli.model.num_classes == 5
 
+    cli.trainer.fit(cli.model)
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    hparams.pop("_instantiator")
+    assert hparams == {"batch_size": 8, "num_classes": 5}
+
+
+class CustomAdam(torch.optim.Adam):
+    def __init__(self, params, num_classes: Optional[int] = None, **kwargs):
+        super().__init__(params, **kwargs)
+
+
+class DeepLinkTargetModel(BoringModel):
+    def __init__(
+        self,
+        optimizer: OptimizerCallable = torch.optim.Adam,
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.optimizer = optimizer
+
+    def configure_optimizers(self):
+        optimizer = self.optimizer(self.parameters())
+        return {"optimizer": optimizer}
+
+
+def test_lightning_cli_link_arguments_subcommands_nested_target(cleandir):
+    class MyLightningCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments(
+                "data.num_classes",
+                "model.init_args.optimizer.init_args.num_classes",
+                apply_on="instantiate",
+            )
+
+    cli_args = [
+        "fit",
+        "--data.batch_size=12",
+        "--trainer.max_epochs=1",
+        "--model=tests_pytorch.test_cli.DeepLinkTargetModel",
+        "--model.optimizer=tests_pytorch.test_cli.CustomAdam",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = MyLightningCLI(
+            DeepLinkTargetModel,
+            BoringDataModuleBatchSizeAndClasses,
+            subclass_mode_model=True,
+            auto_configure_optimizers=False,
+        )
+
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    assert hparams["optimizer"]["class_path"] == "tests_pytorch.test_cli.CustomAdam"
+    assert hparams["optimizer"]["init_args"]["num_classes"] == 5
+
 
 class EarlyExitTestModel(BoringModel):
     def on_fit_start(self):
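As a usage note, the nested-target link exercised by the new test can also be driven from a plain script. This is a minimal sketch, not part of the test suite: the `SketchAdam`/`SketchData`/`SketchModel`/`SketchCLI` names, the random training data, and the `lightning.pytorch.cli` import path are all illustrative assumptions.

from typing import Optional

import torch
from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.cli import LightningCLI, OptimizerCallable


class SketchAdam(torch.optim.Adam):
    # Extra constructor argument whose only purpose is to give the link a target.
    def __init__(self, params, num_classes: Optional[int] = None, **kwargs):
        super().__init__(params, **kwargs)


class SketchData(LightningDataModule):
    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.batch_size = batch_size
        self.num_classes = 5  # known only on the instance, hence apply_on="instantiate"

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=self.batch_size)


class SketchModel(LightningModule):
    def __init__(self, optimizer: OptimizerCallable = torch.optim.Adam):
        super().__init__()
        self.save_hyperparameters()
        self.optimizer = optimizer
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return self.optimizer(self.parameters())


class SketchCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # Copy the datamodule's num_classes into the nested optimizer init args.
        parser.link_arguments(
            "data.num_classes",
            "model.init_args.optimizer.init_args.num_classes",
            apply_on="instantiate",
        )


if __name__ == "__main__":
    # e.g. python sketch.py fit --trainer.max_epochs=1 \
    #     --model=__main__.SketchModel --model.optimizer=__main__.SketchAdam
    SketchCLI(SketchModel, SketchData, subclass_mode_model=True, auto_configure_optimizers=False)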