-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Fix: Respect required=False
in add_lightning_class_args
when subclass_mode=False
#20856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Fix: Respect required=False
in add_lightning_class_args
when subclass_mode=False
#20856
Conversation
for more information, see https://pre-commit.ci
required=False
in add_lightning_class_args
when subclass_mode=False
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://lightning.ai/docs/pytorch/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Discord. Thank you for your contributions. |
@mauvilsa mind have look, pls |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Borda this should not be merged.
@KAVYANSHTYAGI unfortunately the presence of the required
parameter mislead you. The only reason why add_lightning_class_args
supports Callback
is to add callbacks that must always be present, see configure-forced-callbacks in the docs. So simply Callback
+ required=False
should give an error. But you are welcome to contribute that error logic instead.
# NEW LOGIC: If subclass_mode=False and required=False, only add if config provides this key | ||
if not subclass_mode and not required: | ||
config_path = f"{self.subcommand}.{nested_key}" if getattr(self, "subcommand", None) else nested_key | ||
config = getattr(self, "config", {}) | ||
if not any(k.startswith(config_path) for k in config): | ||
# Skip adding class arguments | ||
return [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method is only called on parser construction, not during parsing. So at this point there isn't any config. Because of the getattr(self, "config", {})
this doesn't fail but it will never do anything.
@@ -1860,6 +1860,34 @@ def test_lightning_cli_args_and_sys_argv_warning(): | |||
LightningCLI(TestModel, run=False, args=["--model.foo=789"]) | |||
|
|||
|
|||
def test_add_class_args_required_false_skips_addition(tmp_path): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@KAVYANSHTYAGI you did not run the test to check whether the code was working. I know this because it is not possible that it works. When you open a pull request it would be good that beforehand you run the tests. And the test should fail without the code changes and succeed with them.
class FooCheckpoint(callbacks.ModelCheckpoint): | ||
def __init__(self, dirpath, *args, **kwargs): | ||
super().__init__(dirpath, *args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class FooCheckpoint(callbacks.ModelCheckpoint): | |
def __init__(self, dirpath, *args, **kwargs): | |
super().__init__(dirpath, *args, **kwargs) |
There is no reason for this class since ModelCheckpoint
could be used directly.
class SimpleModel: | ||
def __init__(self): | ||
pass | ||
|
||
class SimpleDataModule: | ||
def __init__(self): | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class SimpleModel: | |
def __init__(self): | |
pass | |
class SimpleDataModule: | |
def __init__(self): | |
pass |
LightningCLI
requires modules to inherit from the lightning classes. These don't so there is no way this would work. Also there is no need to implement new classes, that is what BoringModel
and BoringDataModule
are for.
def __init__(self): | ||
super().__init__( | ||
model_class=SimpleModel, datamodule_class=SimpleDataModule, run=False, save_config_callback=None | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def __init__(self): | |
super().__init__( | |
model_class=SimpleModel, datamodule_class=SimpleDataModule, run=False, save_config_callback=None | |
) |
What is the point of this? It doesn't seem relevant to the test.
) | ||
|
||
def add_arguments_to_parser(self, parser): | ||
parser.add_lightning_class_args(FooCheckpoint, "checkpoint", required=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
parser.add_lightning_class_args(FooCheckpoint, "checkpoint", required=False) | |
parser.add_lightning_class_args(ModelCheckpoint, "checkpoint", required=False) |
# Expectation: No error raised even though FooCheckpoint requires `dirpath` | ||
FooCLI() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Expectation: No error raised even though FooCheckpoint requires `dirpath` | |
FooCLI() | |
FooCLI(BoringModel, BoringDataModule, run=False, save_config_callback=None, args=[]) |
This must have at least args=[]
, otherwise it gets the arguments from the pytest call which for sure will not be valid for the CLI.
But anyway, callback
+ required=False
should not be supported. But this test could be changed to assert that a specific error message is shown.
What does this PR do?
This PR fixes an inconsistency in LightningCLI.add_lightning_class_args() where the required=False flag was being ignored when subclass_mode=False (which is the default).
According to the documentation, setting required=False should make providing the corresponding config optional — but in practice, argument registration was enforced regardless, leading to errors if required constructor parameters (e.g., dirpath in ModelCheckpoint) were not provided.
This change adds a conditional check that skips argument registration when:
and the config key (nested_key) is not provided by the user
Additionally, a test case was added to verify that this behavior works as expected.
Fixes #20851 #20851
Before submitting
Was this discussed/agreed via a GitHub issue? yes
Did you read the contributor guideline, Pull Request section? yes
Did you make sure your PR does only one thing, instead of bundling different changes together? yes
Did you write any new necessary tests? yes
Did you verify new and existing tests pass locally with your changes? yes
Did you update the documentation if necessary?
Did you list all the breaking changes introduced by this pull request?
Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)
Anyone in the community is welcome to review the PR.
Yes :)📚 Documentation preview 📚: https://pytorch-lightning--20856.org.readthedocs.build/en/20856/