Skip to content

Fix: Respect required=False in add_lightning_class_args when subclass_mode=False #20856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

KAVYANSHTYAGI
Copy link
Contributor

@KAVYANSHTYAGI KAVYANSHTYAGI commented May 25, 2025

What does this PR do?

This PR fixes an inconsistency in LightningCLI.add_lightning_class_args() where the required=False flag was being ignored when subclass_mode=False (which is the default).
According to the documentation, setting required=False should make providing the corresponding config optional — but in practice, argument registration was enforced regardless, leading to errors if required constructor parameters (e.g., dirpath in ModelCheckpoint) were not provided.

This change adds a conditional check that skips argument registration when:

    subclass_mode=False

    required=False

and the config key (nested_key) is not provided by the user

Additionally, a test case was added to verify that this behavior works as expected.

Fixes #20851 #20851

Before submitting

Was this discussed/agreed via a GitHub issue? yes

Did you read the contributor guideline, Pull Request section? yes

Did you make sure your PR does only one thing, instead of bundling different changes together? yes

Did you write any new necessary tests? yes

Did you verify new and existing tests pass locally with your changes? yes

Did you update the documentation if necessary?

Did you list all the breaking changes introduced by this pull request?

Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.

Yes :)

📚 Documentation preview 📚: https://pytorch-lightning--20856.org.readthedocs.build/en/20856/

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label May 25, 2025
@Borda Borda changed the title Fix: Respect required=False in add_lightning_class_args when subclass_mode=False Fix: Respect required=False in add_lightning_class_args when subclass_mode=False Jun 18, 2025
Copy link

stale bot commented Jul 19, 2025

This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://lightning.ai/docs/pytorch/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Discord. Thank you for your contributions.

@stale stale bot added the won't fix This will not be worked on label Jul 19, 2025
@stale stale bot removed the won't fix This will not be worked on label Aug 8, 2025
@Borda
Copy link
Member

Borda commented Aug 8, 2025

@mauvilsa mind have look, pls

Copy link
Contributor

@mauvilsa mauvilsa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Borda this should not be merged.

@KAVYANSHTYAGI unfortunately the presence of the required parameter mislead you. The only reason why add_lightning_class_args supports Callback is to add callbacks that must always be present, see configure-forced-callbacks in the docs. So simply Callback + required=False should give an error. But you are welcome to contribute that error logic instead.

Comment on lines +149 to +155
# NEW LOGIC: If subclass_mode=False and required=False, only add if config provides this key
if not subclass_mode and not required:
config_path = f"{self.subcommand}.{nested_key}" if getattr(self, "subcommand", None) else nested_key
config = getattr(self, "config", {})
if not any(k.startswith(config_path) for k in config):
# Skip adding class arguments
return []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is only called on parser construction, not during parsing. So at this point there isn't any config. Because of the getattr(self, "config", {}) this doesn't fail but it will never do anything.

@@ -1860,6 +1860,34 @@ def test_lightning_cli_args_and_sys_argv_warning():
LightningCLI(TestModel, run=False, args=["--model.foo=789"])


def test_add_class_args_required_false_skips_addition(tmp_path):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@KAVYANSHTYAGI you did not run the test to check whether the code was working. I know this because it is not possible that it works. When you open a pull request it would be good that beforehand you run the tests. And the test should fail without the code changes and succeed with them.

Comment on lines +1866 to +1868
class FooCheckpoint(callbacks.ModelCheckpoint):
def __init__(self, dirpath, *args, **kwargs):
super().__init__(dirpath, *args, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class FooCheckpoint(callbacks.ModelCheckpoint):
def __init__(self, dirpath, *args, **kwargs):
super().__init__(dirpath, *args, **kwargs)

There is no reason for this class since ModelCheckpoint could be used directly.

Comment on lines +1870 to +1876
class SimpleModel:
def __init__(self):
pass

class SimpleDataModule:
def __init__(self):
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class SimpleModel:
def __init__(self):
pass
class SimpleDataModule:
def __init__(self):
pass

LightningCLI requires modules to inherit from the lightning classes. These don't so there is no way this would work. Also there is no need to implement new classes, that is what BoringModel and BoringDataModule are for.

Comment on lines +1879 to +1882
def __init__(self):
super().__init__(
model_class=SimpleModel, datamodule_class=SimpleDataModule, run=False, save_config_callback=None
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self):
super().__init__(
model_class=SimpleModel, datamodule_class=SimpleDataModule, run=False, save_config_callback=None
)

What is the point of this? It doesn't seem relevant to the test.

)

def add_arguments_to_parser(self, parser):
parser.add_lightning_class_args(FooCheckpoint, "checkpoint", required=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
parser.add_lightning_class_args(FooCheckpoint, "checkpoint", required=False)
parser.add_lightning_class_args(ModelCheckpoint, "checkpoint", required=False)

Comment on lines +1887 to +1888
# Expectation: No error raised even though FooCheckpoint requires `dirpath`
FooCLI()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Expectation: No error raised even though FooCheckpoint requires `dirpath`
FooCLI()
FooCLI(BoringModel, BoringDataModule, run=False, save_config_callback=None, args=[])

This must have at least args=[], otherwise it gets the arguments from the pytest call which for sure will not be valid for the CLI.

But anyway, callback + required=False should not be supported. But this test could be changed to assert that a specific error message is shown.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pl Generic label for PyTorch Lightning package
Projects
None yet
Development

Successfully merging this pull request may close these issues.

add_lightning_class_args required argument ignored if not using subclass mode
3 participants