Skip to content

Commit d2ba52b

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4bab579 commit d2ba52b

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

src/lightning/pytorch/cli.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,50 +126,50 @@ def add_lightning_class_args(
126126
required: bool = True,
127127
) -> list[str]:
128128
"""Adds arguments from a lightning class to a nested key of the parser.
129-
129+
130130
Args:
131131
lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
132132
nested_key: Name of the nested namespace to store arguments.
133133
subclass_mode: Whether to allow any subclass of the given class.
134134
required: Whether the argument group is required.
135-
135+
136136
Returns:
137137
A list with the names of the class arguments added.
138+
138139
"""
139140
if callable(lightning_class) and not isinstance(lightning_class, type):
140141
lightning_class = class_from_function(lightning_class)
141-
142+
142143
if isinstance(lightning_class, type) and issubclass(
143144
lightning_class, (Trainer, LightningModule, LightningDataModule, Callback)
144145
):
145146
if issubclass(lightning_class, Callback):
146147
self.callback_keys.append(nested_key)
147-
148+
148149
# NEW LOGIC: If subclass_mode=False and required=False, only add if config provides this key
149150
if not subclass_mode and not required:
150151
config_path = f"{self.subcommand}.{nested_key}" if getattr(self, "subcommand", None) else nested_key
151152
config = getattr(self, "config", {})
152-
if not any(k.startswith(config_path) for k in config.keys()):
153+
if not any(k.startswith(config_path) for k in config):
153154
# Skip adding class arguments
154155
return []
155-
156+
156157
if subclass_mode:
157158
return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)
158-
159+
159160
return self.add_class_arguments(
160161
lightning_class,
161162
nested_key,
162163
fail_untyped=False,
163164
instantiate=not issubclass(lightning_class, Trainer),
164165
sub_configs=True,
165166
)
166-
167+
167168
raise MisconfigurationException(
168169
f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: "
169170
"Trainer, LightningModule, LightningDataModule, or Callback."
170171
)
171172

172-
173173
def add_optimizer_args(
174174
self,
175175
optimizer_class: Union[type[Optimizer], tuple[type[Optimizer], ...]] = (Optimizer,),

tests/tests_pytorch/test_cli.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,30 +1792,28 @@ def test_lightning_cli_args_and_sys_argv_warning():
17921792

17931793

17941794
def test_add_class_args_required_false_skips_addition(tmp_path):
1795-
from lightning.pytorch import cli, callbacks
1795+
from lightning.pytorch import callbacks, cli
17961796

17971797
class FooCheckpoint(callbacks.ModelCheckpoint):
17981798
def __init__(self, dirpath, *args, **kwargs):
17991799
super().__init__(dirpath, *args, **kwargs)
18001800

18011801
class SimpleModel:
1802-
def __init__(self): pass
1802+
def __init__(self):
1803+
pass
18031804

18041805
class SimpleDataModule:
1805-
def __init__(self): pass
1806+
def __init__(self):
1807+
pass
18061808

18071809
class FooCLI(cli.LightningCLI):
18081810
def __init__(self):
18091811
super().__init__(
1810-
model_class=SimpleModel,
1811-
datamodule_class=SimpleDataModule,
1812-
run=False,
1813-
save_config_callback=None
1812+
model_class=SimpleModel, datamodule_class=SimpleDataModule, run=False, save_config_callback=None
18141813
)
18151814

18161815
def add_arguments_to_parser(self, parser):
18171816
parser.add_lightning_class_args(FooCheckpoint, "checkpoint", required=False)
18181817

18191818
# Expectation: No error raised even though FooCheckpoint requires `dirpath`
18201819
FooCLI()
1821-

0 commit comments

Comments
 (0)