Skip to content

Fix: Respect required=False in add_lightning_class_args when subclass_mode=False #20856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def add_lightning_class_args(
Args:
lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
nested_key: Name of the nested namespace to store arguments.
subclass_mode: Whether allow any subclass of the given class.
subclass_mode: Whether to allow any subclass of the given class.
required: Whether the argument group is required.

Returns:
Expand All @@ -145,15 +145,26 @@ def add_lightning_class_args(
):
if issubclass(lightning_class, Callback):
self.callback_keys.append(nested_key)

# NEW LOGIC: If subclass_mode=False and required=False, only add if config provides this key
if not subclass_mode and not required:
config_path = f"{self.subcommand}.{nested_key}" if getattr(self, "subcommand", None) else nested_key
config = getattr(self, "config", {})
if not any(k.startswith(config_path) for k in config):
# Skip adding class arguments
return []
Comment on lines +149 to +155
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is only called on parser construction, not during parsing. So at this point there isn't any config. Because of the getattr(self, "config", {}) this doesn't fail but it will never do anything.


if subclass_mode:
return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)

return self.add_class_arguments(
lightning_class,
nested_key,
fail_untyped=False,
instantiate=not issubclass(lightning_class, Trainer),
sub_configs=True,
)

raise MisconfigurationException(
f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: "
"Trainer, LightningModule, LightningDataModule, or Callback."
Expand Down
28 changes: 28 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,6 +1860,34 @@ def test_lightning_cli_args_and_sys_argv_warning():
LightningCLI(TestModel, run=False, args=["--model.foo=789"])


def test_add_class_args_required_false_skips_addition(tmp_path):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@KAVYANSHTYAGI you did not run the test to check whether the code was working. I know this because it is not possible that it works. When you open a pull request it would be good that beforehand you run the tests. And the test should fail without the code changes and succeed with them.

from lightning.pytorch import callbacks, cli

class FooCheckpoint(callbacks.ModelCheckpoint):
def __init__(self, dirpath, *args, **kwargs):
super().__init__(dirpath, *args, **kwargs)
Comment on lines +1866 to +1868
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class FooCheckpoint(callbacks.ModelCheckpoint):
def __init__(self, dirpath, *args, **kwargs):
super().__init__(dirpath, *args, **kwargs)

There is no reason for this class since ModelCheckpoint could be used directly.


class SimpleModel:
def __init__(self):
pass

class SimpleDataModule:
def __init__(self):
pass
Comment on lines +1870 to +1876
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class SimpleModel:
def __init__(self):
pass
class SimpleDataModule:
def __init__(self):
pass

LightningCLI requires modules to inherit from the lightning classes. These don't so there is no way this would work. Also there is no need to implement new classes, that is what BoringModel and BoringDataModule are for.


class FooCLI(cli.LightningCLI):
def __init__(self):
super().__init__(
model_class=SimpleModel, datamodule_class=SimpleDataModule, run=False, save_config_callback=None
)
Comment on lines +1879 to +1882
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self):
super().__init__(
model_class=SimpleModel, datamodule_class=SimpleDataModule, run=False, save_config_callback=None
)

What is the point of this? It doesn't seem relevant to the test.


def add_arguments_to_parser(self, parser):
parser.add_lightning_class_args(FooCheckpoint, "checkpoint", required=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
parser.add_lightning_class_args(FooCheckpoint, "checkpoint", required=False)
parser.add_lightning_class_args(ModelCheckpoint, "checkpoint", required=False)


# Expectation: No error raised even though FooCheckpoint requires `dirpath`
FooCLI()
Comment on lines +1887 to +1888
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Expectation: No error raised even though FooCheckpoint requires `dirpath`
FooCLI()
FooCLI(BoringModel, BoringDataModule, run=False, save_config_callback=None, args=[])

This must have at least args=[], otherwise it gets the arguments from the pytest call which for sure will not be valid for the CLI.

But anyway, callback + required=False should not be supported. But this test could be changed to assert that a specific error message is shown.



def test_lightning_cli_jsonnet(cleandir):
class MainModule(BoringModel):
def __init__(self, main_param: int = 1):
Expand Down
Loading