Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
matplotlib>3.1, <3.10.0
omegaconf >=2.2.3, <2.4.0
hydra-core >=1.2.0, <1.4.0
jsonargparse[signatures] >=4.39.0, <4.41.0
jsonargparse[signatures,jsonnet] >=4.39.0, <4.41.0
rich >=12.3.0, <14.1.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin"
4 changes: 3 additions & 1 deletion src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,9 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
def _dump_config(self) -> None:
if hasattr(self, "config_dump"):
return
self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
self.config_dump = yaml.safe_load(
self.parser.dump(self.config, skip_link_targets=False, skip_none=False, format="yaml")
)
if "subcommand" in self.config:
self.config_dump = self.config_dump[self.config.subcommand]

Expand Down
20 changes: 20 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,3 +1858,23 @@ def test_lightning_cli_with_args_given(args):
def test_lightning_cli_args_and_sys_argv_warning():
with mock.patch("sys.argv", ["", "--model.foo=456"]), pytest.warns(Warning, match="LightningCLI's args parameter "):
LightningCLI(TestModel, run=False, args=["--model.foo=789"])


def test_lightning_cli_jsonnet(cleandir):
class MainModule(BoringModel):
def __init__(self, main_param: int = 1):
super().__init__()

config = """{
"model":{
"main_param": 2
}
}"""
config_path = Path("config.jsonnet")
config_path.write_text(config)

cli_args = [f"--config={config_path}"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(MainModule, run=False, parser_kwargs={"parser_mode": "jsonnet"})

assert cli.config["model"]["main_param"] == 2
Loading