Skip to content

Commit 4e08de2

Browse files
kaushikb11lexierule
authored andcommitted
[CLI] Add support for --key.help=class
1 parent bf9aef7 commit 4e08de2

File tree

4 files changed

+55
-8
lines changed

4 files changed

+55
-8
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Fixed
1111

12+
- Fixed support for `--key.help=class` with the `LightningCLI` ([#10767](https://github.com/PyTorchLightning/pytorch-lightning/pull/10767))
13+
14+
1215
- Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762))
1316

1417

pytorch_lightning/utilities/cli.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,27 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis
265265
else:
266266
clean_argv.append(arg)
267267
i += 1
268+
269+
# the user requested a help message
270+
help_key = argv_key + ".help"
271+
if help_key in passed_args:
272+
argv_class = passed_args[help_key]
273+
if "." in argv_class:
274+
# user passed the class path directly
275+
class_path = argv_class
276+
else:
277+
# convert shorthand format to the classpath
278+
for cls in classes:
279+
if cls.__name__ == argv_class:
280+
class_path = _class_path_from_class(cls)
281+
break
282+
else:
283+
raise ValueError(f"Could not generate get the class_path for {repr(argv_class)}")
284+
return clean_argv + [help_key, class_path]
285+
268286
# generate the associated config file
269-
argv_class = passed_args.pop(argv_key, None)
270-
if argv_class is None:
287+
argv_class = passed_args.pop(argv_key, "")
288+
if not argv_class:
271289
# the user passed a config as a str
272290
class_path = passed_args[f"{argv_key}.class_path"]
273291
init_args_key = f"{argv_key}.init_args"
@@ -772,8 +790,16 @@ def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]:
772790
return fn_kwargs
773791

774792

775-
def _global_add_class_path(class_type: Type, init_args: Dict[str, Any] = None) -> Dict[str, Any]:
776-
return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args or {}}
793+
def _class_path_from_class(class_type: Type) -> str:
794+
return class_type.__module__ + "." + class_type.__name__
795+
796+
797+
def _global_add_class_path(
798+
class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None
799+
) -> Dict[str, Any]:
800+
if isinstance(init_args, Namespace):
801+
init_args = init_args.as_dict()
802+
return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}}
777803

778804

779805
def _add_class_path_generator(class_type: Type) -> Callable[[Dict[str, Any]], Dict[str, Any]]:

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already insta
55
torchtext>=0.8.*
66
omegaconf>=2.0.5
77
hydra-core>=1.0.5
8-
jsonargparse[signatures]>=4.0.0
8+
jsonargparse[signatures]>=4.0.4
99
gcsfs>=2021.5.0
1010
rich>=10.2.2

tests/utilities/test_cli.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858

5959
@mock.patch("argparse.ArgumentParser.parse_args")
60-
def test_default_args(mock_argparse, tmpdir):
60+
def test_default_args(mock_argparse):
6161
"""Tests default argument parser for Trainer."""
6262
mock_argparse.return_value = Namespace(**Trainer.default_attributes())
6363

@@ -868,7 +868,7 @@ class CustomCallback(Callback):
868868
pass
869869

870870

871-
def test_registries(tmpdir):
871+
def test_registries():
872872
assert "SGD" in OPTIMIZER_REGISTRY.names
873873
assert "RMSprop" in OPTIMIZER_REGISTRY.names
874874
assert "CustomAdam" in OPTIMIZER_REGISTRY.names
@@ -1358,9 +1358,27 @@ class TestCallback(Callback):
13581358
assert cli.config_init["trainer"]["max_epochs"] is None
13591359

13601360

1361-
def test_cli_configure_optimizers_warning(tmpdir):
1361+
def test_cli_configure_optimizers_warning():
13621362
match = "configure_optimizers` will be overridden by `LightningCLI"
13631363
with mock.patch("sys.argv", ["any.py"]), no_warning_call(UserWarning, match=match):
13641364
LightningCLI(BoringModel, run=False)
13651365
with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]), pytest.warns(UserWarning, match=match):
13661366
LightningCLI(BoringModel, run=False)
1367+
1368+
1369+
def test_cli_help_message():
1370+
# full class path
1371+
cli_args = ["any.py", "--optimizer.help=torch.optim.Adam"]
1372+
classpath_help = StringIO()
1373+
with mock.patch("sys.argv", cli_args), redirect_stdout(classpath_help), pytest.raises(SystemExit):
1374+
LightningCLI(BoringModel, run=False)
1375+
1376+
cli_args = ["any.py", "--optimizer.help=Adam"]
1377+
shorthand_help = StringIO()
1378+
with mock.patch("sys.argv", cli_args), redirect_stdout(shorthand_help), pytest.raises(SystemExit):
1379+
LightningCLI(BoringModel, run=False)
1380+
1381+
# the help messages should match
1382+
assert shorthand_help.getvalue() == classpath_help.getvalue()
1383+
# make sure it's not empty
1384+
assert "Implements Adam" in shorthand_help.getvalue()

0 commit comments

Comments
 (0)