Skip to content

Commit 79de6a9

Browse files
mauvilsa and carmocca authored
LightningCLI natively support callback list append (Lightning-AI#13129)
* LightningCLI natively support callback list append. * Update jsonargparse version * Handle case where callbacks is not a list. * Fix PEP8 issue. * Handle mypy false positive Co-authored-by: Carlos Mocholí <[email protected]>
1 parent c1f0502 commit 79de6a9

File tree

5 files changed

+30
-159
lines changed

5 files changed

+30
-159
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8888
- `LightningCLI`'s shorthand notation changed to use jsonargparse native feature ([#12614](https://github.com/PyTorchLightning/pytorch-lightning/pull/12614))
8989

9090

91+
- `LightningCLI` changed to use jsonargparse native support for list append ([#13129](https://github.com/PyTorchLightning/pytorch-lightning/pull/13129))
92+
93+
9194
- Changed `seed_everything_default` argument in the `LightningCLI` to type `Union[bool, int]`. If set to `True` a seed is automatically generated for the parser argument `--seed_everything`. ([#12822](https://github.com/PyTorchLightning/pytorch-lightning/pull/12822), [#13110](https://github.com/PyTorchLightning/pytorch-lightning/pull/13110))
9295

9396

docs/source/cli/lightning_cli_advanced_3.rst

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ The argument's order matters and the user needs to pass the arguments in the fol
119119
.. code-block:: bash
120120
121121
$ python ... \
122-
--trainer.callbacks={CALLBACK_1_NAME} \
122+
--trainer.callbacks+={CALLBACK_1_NAME} \
123123
--trainer.callbacks.{CALLBACK_1_ARGS_1}=... \
124124
--trainer.callbacks.{CALLBACK_1_ARGS_2}=... \
125125
...
126-
--trainer.callbacks={CALLBACK_N_NAME} \
126+
--trainer.callbacks+={CALLBACK_N_NAME} \
127127
--trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
128128
...
129129
@@ -132,9 +132,9 @@ Here is an example:
132132
.. code-block:: bash
133133
134134
$ python ... \
135-
--trainer.callbacks=EarlyStopping \
135+
--trainer.callbacks+=EarlyStopping \
136136
--trainer.callbacks.patience=5 \
137-
--trainer.callbacks=LearningRateMonitor \
137+
--trainer.callbacks+=LearningRateMonitor \
138138
--trainer.callbacks.logging_interval=epoch
139139
140140
Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification
@@ -154,12 +154,12 @@ as described above:
154154
155155
.. code-block:: bash
156156
157-
$ python ... --trainer.callbacks=CustomCallback ...
157+
$ python ... --trainer.callbacks+=CustomCallback ...
158158
159159
.. note::
160160

161-
This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
162-
generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.
161+
This shorthand notation is also supported inside a configuration file. The configuration file
162+
generated by calling the previous command with ``--print_config`` will have the full ``class_path`` notation.
163163

164164
.. code-block:: yaml
165165

pytorch_lightning/utilities/cli.py

Lines changed: 6 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@
1515

1616
import inspect
1717
import os
18-
import sys
1918
from functools import partial, update_wrapper
2019
from types import MethodType, ModuleType
2120
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union
22-
from unittest import mock
2321

2422
import torch
25-
import yaml
2623
from torch.optim import Optimizer
2724

2825
import pytorch_lightning as pl
@@ -34,7 +31,7 @@
3431
from pytorch_lightning.utilities.model_helpers import is_overridden
3532
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
3633

37-
_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.8.0")
34+
_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.9.0")
3835

3936
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
4037
import docstring_parser
@@ -46,8 +43,6 @@
4643
register_unresolvable_import_paths,
4744
set_config_read_mode,
4845
)
49-
from jsonargparse.typehints import get_all_subclass_paths
50-
from jsonargparse.util import import_object
5146

5247
register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483
5348
set_config_read_mode(fsspec_enabled=True)
@@ -262,73 +257,6 @@ def add_lr_scheduler_args(
262257
self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
263258
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
264259

265-
def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
266-
argv = sys.argv
267-
nested_key = "trainer.callbacks"
268-
if any(arg.startswith(f"--{nested_key}") for arg in argv):
269-
classes = tuple(import_object(x) for x in get_all_subclass_paths(Callback))
270-
argv = self._convert_argv_issue_85(classes, nested_key, argv)
271-
with mock.patch("sys.argv", argv):
272-
return super().parse_args(*args, **kwargs)
273-
274-
@staticmethod
275-
def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
276-
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/85.
277-
278-
Adds support for shorthand notation for ``List[object]`` arguments.
279-
"""
280-
passed_args, clean_argv = [], []
281-
passed_configs = {}
282-
argv_key = f"--{nested_key}"
283-
# get the argv args for this nested key
284-
i = 0
285-
while i < len(argv):
286-
arg = argv[i]
287-
if arg.startswith(argv_key):
288-
if "=" in arg:
289-
key, value = arg.split("=")
290-
else:
291-
key = arg
292-
i += 1
293-
value = argv[i]
294-
if "class_path" in value:
295-
# the user passed a config as a dict
296-
passed_configs[key] = yaml.safe_load(value)
297-
else:
298-
passed_args.append((key, value))
299-
else:
300-
clean_argv.append(arg)
301-
i += 1
302-
# generate the associated config file
303-
config = []
304-
i, n = 0, len(passed_args)
305-
while i < n - 1:
306-
ki, vi = passed_args[i]
307-
# convert class name to class path
308-
for cls in classes:
309-
if cls.__name__ == vi:
310-
cls_type = cls
311-
break
312-
else:
313-
raise ValueError(f"Could not generate a config for {repr(vi)}")
314-
config.append(_global_add_class_path(cls_type))
315-
# get any init args
316-
j = i + 1 # in case the j-loop doesn't run
317-
for j in range(i + 1, n):
318-
kj, vj = passed_args[j]
319-
if ki == kj:
320-
break
321-
if kj.startswith(ki):
322-
init_arg_name = kj.split(".")[-1]
323-
config[-1]["init_args"][init_arg_name] = vj
324-
i = j
325-
# update at the end to preserve the order
326-
for k, v in passed_configs.items():
327-
config.extend(v)
328-
if not config:
329-
return clean_argv
330-
return clean_argv + [argv_key, str(config)]
331-
332260

333261
class SaveConfigCallback(Callback):
334262
"""Saves a LightningCLI config to the log_dir when training starts.
@@ -648,7 +576,11 @@ def instantiate_trainer(self, **kwargs: Any) -> Trainer:
648576
return self._instantiate_trainer(trainer_config, extra_callbacks)
649577

650578
def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
651-
config["callbacks"] = config["callbacks"] or []
579+
if config["callbacks"] is None:
580+
config["callbacks"] = []
581+
elif not isinstance(config["callbacks"], list):
582+
config["callbacks"] = [config["callbacks"]]
583+
assert isinstance(config["callbacks"], list) # to handle mypy false positive
652584
config["callbacks"].extend(callbacks)
653585
if "callbacks" in self.trainer_defaults:
654586
if isinstance(self.trainer_defaults["callbacks"], list):

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ matplotlib>3.1, <3.5.3
44
torchtext>=0.9.*, <=0.12.0
55
omegaconf>=2.0.5, <=2.1.*
66
hydra-core>=1.0.5, <=1.1.*
7-
jsonargparse[signatures]>=4.8.0, <=4.8.0
7+
jsonargparse[signatures]>=4.9.0, <=4.9.0
88
gcsfs>=2021.5.0, <=2022.2.0
99
rich>=10.2.2,!=10.15.*, <=12.0.0

tests/utilities/test_cli.py

Lines changed: 13 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,14 @@ def on_fit_start(self):
318318
assert cli.trainer.ran_asserts
319319

320320

321+
def test_lightning_cli_single_arg_callback():
322+
with mock.patch("sys.argv", ["any.py", "--trainer.callbacks=DeviceStatsMonitor"]):
323+
cli = LightningCLI(BoringModel, run=False)
324+
325+
assert cli.config.trainer.callbacks.class_path == "pytorch_lightning.callbacks.DeviceStatsMonitor"
326+
assert not isinstance(cli.config_init.trainer, list)
327+
328+
321329
@pytest.mark.parametrize("run", (False, True))
322330
def test_lightning_cli_configurable_callbacks(tmpdir, run):
323331
class MyLightningCLI(LightningCLI):
@@ -1046,19 +1054,20 @@ def test_lightning_cli_datamodule_short_arguments():
10461054

10471055

10481056
@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
1049-
def test_registries_resolution(use_class_path_callbacks):
1057+
def test_callbacks_append(use_class_path_callbacks):
10501058

10511059
"""This test validates registries are used when simplified command line are being used."""
10521060
cli_args = [
10531061
"--optimizer",
10541062
"Adam",
10551063
"--optimizer.lr",
10561064
"0.0001",
1057-
"--trainer.callbacks=LearningRateMonitor",
1065+
"--trainer.callbacks+=LearningRateMonitor",
10581066
"--trainer.callbacks.logging_interval=epoch",
10591067
"--trainer.callbacks.log_momentum=True",
10601068
"--model=BoringModel",
1061-
"--trainer.callbacks=ModelCheckpoint",
1069+
"--trainer.callbacks+",
1070+
"ModelCheckpoint",
10621071
"--trainer.callbacks.monitor=loss",
10631072
"--lr_scheduler",
10641073
"StepLR",
@@ -1071,7 +1080,7 @@ def test_registries_resolution(use_class_path_callbacks):
10711080
{"class_path": "pytorch_lightning.callbacks.Callback"},
10721081
{"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}},
10731082
]
1074-
cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"]
1083+
cli_args += [f"--trainer.callbacks+={json.dumps(callbacks)}"]
10751084
extras = [Callback, Callback]
10761085

10771086
with mock.patch("sys.argv", ["any.py"] + cli_args), mock_subclasses(LightningModule, BoringModel):
@@ -1088,79 +1097,6 @@ def test_registries_resolution(use_class_path_callbacks):
10881097
assert all(t in callback_types for t in expected)
10891098

10901099

1091-
def test_argv_transformation_noop():
1092-
base = ["any.py", "--trainer.max_epochs=1"]
1093-
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", base)
1094-
assert argv == base
1095-
1096-
1097-
def test_argv_transformation_single_callback():
1098-
base = ["any.py", "--trainer.max_epochs=1"]
1099-
input = base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]
1100-
callbacks = [
1101-
{
1102-
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
1103-
"init_args": {"monitor": "val_loss"},
1104-
}
1105-
]
1106-
expected = base + ["--trainer.callbacks", str(callbacks)]
1107-
_populate_registries(False)
1108-
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
1109-
assert argv == expected
1110-
1111-
1112-
def test_argv_transformation_multiple_callbacks():
1113-
base = ["any.py", "--trainer.max_epochs=1"]
1114-
input = base + [
1115-
"--trainer.callbacks=ModelCheckpoint",
1116-
"--trainer.callbacks.monitor=val_loss",
1117-
"--trainer.callbacks=ModelCheckpoint",
1118-
"--trainer.callbacks.monitor=val_acc",
1119-
]
1120-
callbacks = [
1121-
{
1122-
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
1123-
"init_args": {"monitor": "val_loss"},
1124-
},
1125-
{
1126-
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
1127-
"init_args": {"monitor": "val_acc"},
1128-
},
1129-
]
1130-
expected = base + ["--trainer.callbacks", str(callbacks)]
1131-
_populate_registries(False)
1132-
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
1133-
assert argv == expected
1134-
1135-
1136-
def test_argv_transformation_multiple_callbacks_with_config():
1137-
base = ["any.py", "--trainer.max_epochs=1"]
1138-
nested_key = "trainer.callbacks"
1139-
input = base + [
1140-
f"--{nested_key}=ModelCheckpoint",
1141-
f"--{nested_key}.monitor=val_loss",
1142-
f"--{nested_key}=ModelCheckpoint",
1143-
f"--{nested_key}.monitor=val_acc",
1144-
f"--{nested_key}=[{{'class_path': 'pytorch_lightning.callbacks.Callback'}}]",
1145-
]
1146-
callbacks = [
1147-
{
1148-
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
1149-
"init_args": {"monitor": "val_loss"},
1150-
},
1151-
{
1152-
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
1153-
"init_args": {"monitor": "val_acc"},
1154-
},
1155-
{"class_path": "pytorch_lightning.callbacks.Callback"},
1156-
]
1157-
expected = base + ["--trainer.callbacks", str(callbacks)]
1158-
nested_key = "trainer.callbacks"
1159-
_populate_registries(False)
1160-
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input)
1161-
assert argv == expected
1162-
1163-
11641100
def test_optimizers_and_lr_schedulers_reload(tmpdir):
11651101
base = ["any.py", "--trainer.max_epochs=1"]
11661102
input = base + [

0 commit comments

Comments (0)