Skip to content

Commit aeb9e29

Browse files
committed
Add message about ckpt_path hyperparameters when parsing fails
1 parent 9061e15 commit aeb9e29

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

src/lightning/pytorch/cli.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,11 @@ def _parse_ckpt_path(self) -> None:
564564
hparams.pop("_instantiator", None)
565565
if hparams:
566566
hparams = {self.config.subcommand: {"model": hparams}}
567-
self.config = self.parser.parse_object(hparams, self.config)
567+
try:
568+
self.config = self.parser.parse_object(hparams, self.config)
569+
except SystemExit:
570+
sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n")
571+
raise
568572

569573
def _dump_config(self) -> None:
570574
if hasattr(self, "config_dump"):

tests/tests_pytorch/test_cli.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import operator
1818
import os
1919
import sys
20-
from contextlib import ExitStack, contextmanager, redirect_stdout
20+
from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout
2121
from io import StringIO
2222
from pathlib import Path
2323
from typing import Callable, Optional, Union
@@ -520,6 +520,11 @@ def add_arguments_to_parser(self, parser):
520520
assert cli.config.predict.model.hidden_dim == 6
521521
assert cli.config_init.predict.model.layer.out_features == 3
522522

523+
err = StringIO()
524+
with mock.patch("sys.argv", ["any.py"] + cli_args), redirect_stderr(err), pytest.raises(SystemExit):
525+
cli = LightningCLI(BoringModel)
526+
assert "Parsing of ckpt_path hyperparameters failed" in err.getvalue()
527+
523528

524529
def test_lightning_cli_submodules(cleandir):
525530
class MainModule(BoringModel):

0 commit comments

Comments
 (0)