Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,9 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
def _dump_config(self) -> None:
if hasattr(self, "config_dump"):
return
self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
self.config_dump = yaml.safe_load(
self.parser.dump(self.config, skip_link_targets=False, skip_none=False, format="yaml")
)
if "subcommand" in self.config:
self.config_dump = self.config_dump[self.config.subcommand]

Expand Down
20 changes: 20 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,3 +1858,23 @@ def test_lightning_cli_with_args_given(args):
def test_lightning_cli_args_and_sys_argv_warning():
with mock.patch("sys.argv", ["", "--model.foo=456"]), pytest.warns(Warning, match="LightningCLI's args parameter "):
LightningCLI(TestModel, run=False, args=["--model.foo=789"])


def test_lightning_cli_jsonnet(cleandir):
class MainModule(BoringModel):
def __init__(self, main_param: int = 1):
super().__init__()

config = """{
"model":{
"main_param": 2
}
}"""
config_path = Path("config.jsonnet")
config_path.write_text(config)

cli_args = [f"--config={config_path}"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(MainModule, run=False, parser_kwargs={"parser_mode": "jsonnet"})

assert cli.config["model"]["main_param"] == 2
Loading