diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index 2579b701f1faf..0f11b19c23431 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -5,7 +5,7 @@ matplotlib>3.1, <3.10.0 omegaconf >=2.2.3, <2.4.0 hydra-core >=1.2.0, <1.4.0 -jsonargparse[signatures] >=4.39.0, <4.41.0 +jsonargparse[signatures,jsonnet] >=4.39.0, <4.41.0 rich >=12.3.0, <14.1.0 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin" diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 2a6226c42cde1..225296240674a 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -554,7 +554,9 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No def _dump_config(self) -> None: if hasattr(self, "config_dump"): return - self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False)) + self.config_dump = yaml.safe_load( + self.parser.dump(self.config, skip_link_targets=False, skip_none=False, format="yaml") + ) if "subcommand" in self.config: self.config_dump = self.config_dump[self.config.subcommand] diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 59ce4cfe4bb71..1b883dda0282a 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1858,3 +1858,23 @@ def test_lightning_cli_with_args_given(args): def test_lightning_cli_args_and_sys_argv_warning(): with mock.patch("sys.argv", ["", "--model.foo=456"]), pytest.warns(Warning, match="LightningCLI's args parameter "): LightningCLI(TestModel, run=False, args=["--model.foo=789"]) + + +def test_lightning_cli_jsonnet(cleandir): + class MainModule(BoringModel): + def __init__(self, main_param: int = 1): + super().__init__() + + config = """{ + "model":{ + "main_param": 2 + } + }""" + config_path = Path("config.jsonnet") + config_path.write_text(config) + + cli_args = [f"--config={config_path}"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = LightningCLI(MainModule, run=False, parser_kwargs={"parser_mode": "jsonnet"}) + + assert cli.config["model"]["main_param"] == 2