Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,10 @@ def add_lightning_class_args(
if issubclass(lightning_class, Callback):
self.callback_keys.append(nested_key)
if subclass_mode:
print("Subclass mode")
print(nested_key)
return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)
print(nested_key)
return self.add_class_arguments(
lightning_class,
nested_key,
Expand Down
47 changes: 47 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,53 @@ def add_arguments_to_parser(self, parser):
assert cli.model.num_classes == 5


# Notes:
# - if variable is class attribute, it will only be present after instantiation -> apply_on="instantiate"
# - if you link argumetns and then try to pass them additionally, it will raise an error
# - if you initilize the model from the cmd instead of passing it as an instance to CLI init_args needs to be
# added additionally
# - if you initilize the model from the cmd instead of passing it as an instance to CLI apply on instantiate needs to
# be always true for linked arguments


def test_lightning_cli_link_arguments_init():
# Will not work without init_args ("--data.init_args.batch_size=12")
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("data.batch_size", "model.batch_size")
parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")

cli_args = [
"--data.batch_size=12",
]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)

assert cli.model.batch_size == 12
assert cli.datamodule.batch_size == 12
assert cli.model.num_classes == 5
assert cli.datamodule.num_classes == 5

# Will work without init_args ("--data.batch_size=12")
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("data.batch_size", "model.init_args.batch_size", apply_on="instantiate")
parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")

cli_args = [
"--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses",
"--model=tests_pytorch.test_cli.BoringModelRequiredClasses",
"--data.batch_size=12",
]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(run=False)

assert cli.datamodule.batch_size == 12
assert cli.model.batch_size == 12


class EarlyExitTestModel(BoringModel):
def on_fit_start(self):
raise MisconfigurationException("Error on fit start")
Expand Down