Skip to content

Commit 6d5e22d

Browse files
carmoccalexierule
authored andcommitted
[CLI] Save only the configuration used (#11532)
1 parent 9811284 commit 6d5e22d

File tree

3 files changed

+8
-11
lines changed

3 files changed

+8
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111
- Pin sphinx-autodoc-typehints with <v1.15 ([#11400](https://github.com/PyTorchLightning/pytorch-lightning/pull/11400))
1212
- Skip testing with PyTorch 1.7 and Python 3.9 on Ubuntu ([#11217](https://github.com/PyTorchLightning/pytorch-lightning/pull/11217))
1313
- Fixed type promotion when tensors of higher category than float are logged ([#11401](https://github.com/PyTorchLightning/pytorch-lightning/pull/11401))
14+
- Fixed the format of the configuration saved automatically by the CLI's `SaveConfigCallback` ([#11532](https://github.com/PyTorchLightning/pytorch-lightning/pull/11532))
1415

1516
### Changed
1617

1718
- Changed `LSFEnvironment` to use `LSB_DJOB_RANKFILE` environment variable instead of `LSB_HOSTS` for determining node rank and main address ([#10825](https://github.com/PyTorchLightning/pytorch-lightning/pull/10825))
18-
19-
2019
- Disbled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
2120

2221

pytorch_lightning/utilities/cli.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -685,8 +685,8 @@ def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]
685685
config["callbacks"].append(self.trainer_defaults["callbacks"])
686686
if self.save_config_callback and not config["fast_dev_run"]:
687687
config_callback = self.save_config_callback(
688-
self.parser,
689-
self.config,
688+
self._parser(self.subcommand),
689+
self.config.get(str(self.subcommand), self.config),
690690
self.save_config_filename,
691691
overwrite=self.save_config_overwrite,
692692
multifile=self.save_config_multifile,
@@ -769,9 +769,7 @@ def configure_optimizers(
769769

770770
def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any:
771771
"""Utility to get a config value which might be inside a subcommand."""
772-
if self.subcommand is not None:
773-
return config[self.subcommand].get(key, default)
774-
return config.get(key, default)
772+
return config.get(str(self.subcommand), config).get(key, default)
775773

776774
def _run_subcommand(self, subcommand: str) -> None:
777775
"""Run the chosen subcommand."""

tests/utilities/test_cli.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,7 @@ def test_lightning_cli_args(tmpdir):
347347
with open(config_path) as f:
348348
loaded_config = yaml.safe_load(f.read())
349349

350-
loaded_config = loaded_config["fit"]
351350
cli_config = cli.config["fit"]
352-
353351
assert cli_config["seed_everything"] == 1234
354352
assert "model" not in loaded_config and "model" not in cli_config # no arguments to include
355353
assert loaded_config["data"] == cli_config["data"]
@@ -403,9 +401,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
403401
with open(config_path) as f:
404402
loaded_config = yaml.safe_load(f.read())
405403

406-
loaded_config = loaded_config["fit"]
407404
cli_config = cli.config["fit"]
408-
409405
assert loaded_config["model"] == cli_config["model"]
410406
assert loaded_config["data"] == cli_config["data"]
411407
assert loaded_config["trainer"] == cli_config["trainer"]
@@ -1251,6 +1247,10 @@ def test_lightning_cli_config_before_subcommand():
12511247
test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
12521248
assert cli.trainer.limit_test_batches == 1
12531249

1250+
save_config_callback = cli.trainer.callbacks[0]
1251+
assert save_config_callback.config["trainer"]["limit_test_batches"] == 1
1252+
assert save_config_callback.parser.subcommand == "test"
1253+
12541254
with mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), mock.patch(
12551255
"pytorch_lightning.Trainer.validate", autospec=True
12561256
) as validate_mock:

0 commit comments

Comments
 (0)