Skip to content

Commit 09bf9db

Browse files
mtarclessig
authored andcommitted
[1539][infra] Adds base config flag (ecmwf#1573)
* set base config (ecmwf#1539) * update help message * longer variable name * longer variable name * rename config variable * rename base_configs --------- Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
1 parent 44cab76 commit 09bf9db

File tree

4 files changed

+61
-12
lines changed

4 files changed

+61
-12
lines changed

packages/common/src/weathergen/common/config.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def load_merge_configs(
281281
private_home: Path | None = None,
282282
from_run_id: str | None = None,
283283
mini_epoch: int | None = None,
284+
base: Path | Config | None = None,
284285
*overwrites: Path | dict | Config,
285286
) -> Config:
286287
"""
@@ -292,6 +293,7 @@ def load_merge_configs(
292293
from_run_id: Run id of the pretrained WeatherGenerator model
293294
to continue training or inference
294295
mini_epoch: Mini_epoch of the checkpoint to load. -1 indicates last checkpoint available.
296+
base: Path to the base configuration file. Uses default configuration if None.
295297
*overwrites: Additional overwrites from different sources
296298
297299
Note: The order of precedence for merging the final config is in ascending order:
@@ -322,7 +324,7 @@ def load_merge_configs(
322324
private_config = set_paths(private_config)
323325

324326
if from_run_id is None:
325-
base_config = _load_default_conf()
327+
base_config = _load_base_conf(base)
326328
else:
327329
base_config = load_run_config(
328330
from_run_id, mini_epoch, private_config.get("model_path", None)
@@ -485,11 +487,20 @@ def _load_private_conf(private_home: Path | None = None) -> DictConfig:
485487
return private_cf
486488

487489

488-
def _load_default_conf() -> Config:
489-
"""Deserialize default configuration."""
490-
c = OmegaConf.load(_DEFAULT_CONFIG_PTH)
491-
assert isinstance(c, Config)
492-
return c
490+
def _load_base_conf(base: Path | Config | None) -> Config:
491+
"""Return the base configuration"""
492+
match base :
493+
case Path():
494+
_logger.info(f"Loading specified base config from file: {base}.")
495+
conf = OmegaConf.load(base)
496+
case Config():
497+
_logger.info(f"Using existing config as base: {base}.")
498+
conf = base
499+
case _:
500+
_logger.info("Deserialize default configuration.")
501+
conf = OmegaConf.load(_DEFAULT_CONFIG_PTH)
502+
assert isinstance(conf, Config)
503+
return conf
493504

494505

495506
def load_streams(streams_directory: Path) -> list[Config]:

src/weathergen/run_train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def inference_from_args(argl: list[str]):
5555
args.private_config,
5656
args.from_run_id,
5757
args.mini_epoch,
58+
args.base_config,
5859
*args.config,
5960
inference_overwrite,
6061
cli_overwrite,
@@ -106,6 +107,7 @@ def train_continue_from_args(argl: list[str]):
106107
args.from_run_id,
107108
args.mini_epoch,
108109
finetune_overwrite,
110+
args.base_config,
109111
*args.config,
110112
cli_overwrite,
111113
)
@@ -154,7 +156,7 @@ def train_with_args(argl: list[str], stream_dir: str | None):
154156

155157
cli_overwrite = config.from_cli_arglist(args.options)
156158

157-
cf = config.load_merge_configs(args.private_config, None, None, *args.config, cli_overwrite)
159+
cf = config.load_merge_configs(args.private_config, None, None, args.base_config, *args.config, cli_overwrite)
158160
cf = config.set_run_id(cf, args.run_id, False)
159161

160162
cf.data_loader_rng_seed = int(time.time())

src/weathergen/utils/cli.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,14 @@ def _add_general_arguments(parser: argparse.ArgumentParser):
113113
" Individual items should be of the form: parent_obj.nested_obj=value"
114114
),
115115
)
116+
parser.add_argument(
117+
"--base-config",
118+
type=Path,
119+
help=(
120+
"Path to the base configuration file."
121+
"If not provided, ./config/default_config.yml is used."
122+
)
123+
)
116124

117125

118126
def _add_model_loading_params(parser: argparse.ArgumentParser):

tests/test_config.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@
6262
(pathlib.Path("#test.yml"), DUMMY_STREAM_CONF),
6363
]
6464

65+
DUMMY_BASE_CONF = {
66+
"foo": "bar"
67+
}
68+
6569

6670
def contains_keys(super_config, sub_config):
6771
keys_present = [key in super_config.keys() for key in sub_config.keys()]
@@ -152,6 +156,17 @@ def config_fresh(private_config_file):
152156

153157
return cf
154158

159+
@pytest.fixture
160+
def base_config():
161+
return OmegaConf.create(DUMMY_BASE_CONF)
162+
163+
@pytest.fixure
164+
def base_file(base_config):
165+
with tempfile.NamedTemporaryFile("w+") as temp:
166+
temp.write(OmegaConf.to_yaml(base_config))
167+
temp.flush()
168+
yield pathlib.Path(temp.name)
169+
155170

156171
def test_contains_private(config_fresh):
157172
sanitized_private_conf = DUMMY_PRIVATE_CONF.copy()
@@ -167,22 +182,35 @@ def test_is_paths_set(config_fresh):
167182

168183
@pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
169184
def test_load_with_overwrite_dict(overwrite_dict, private_config_file):
170-
cf = config.load_merge_configs(private_config_file, None, None, overwrite_dict)
185+
cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_dict)
171186

172187
assert contains(cf, overwrite_dict)
173188

174189

175190
@pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
176191
def test_load_with_overwrite_config(overwrite_config, private_config_file):
177-
cf = config.load_merge_configs(private_config_file, None, None, overwrite_config)
192+
cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_config)
178193

179194
assert contains(cf, overwrite_config)
180195

181196

182197
@pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
183198
def test_load_with_overwrite_file(private_config_file, overwrite_file):
184199
sub_cf = OmegaConf.load(overwrite_file)
185-
cf = config.load_merge_configs(private_config_file, None, None, overwrite_file)
200+
cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_file)
201+
202+
assert contains(cf, sub_cf)
203+
204+
205+
def test_load_with_base_config(private_config_file, base_config):
206+
cf = config.load_merge_configs(private_config_file, None, None, base_config)
207+
208+
assert contains(cf, base_config)
209+
210+
211+
def test_load_with_base_file(private_config_file, base_file):
212+
sub_cf = OmegaConf.load(base_file)
213+
cf = config.load_merge_configs(private_config_file, None, None, base_file)
186214

187215
assert contains(cf, sub_cf)
188216

@@ -191,7 +219,7 @@ def test_load_with_stream_in_overwrite(private_config_file, streams_dir, mocker)
191219
overwrite = {"streams_directory": streams_dir}
192220
stub = mocker.patch("weathergen.common.config.load_streams", return_value=streams_dir)
193221

194-
config.load_merge_configs(private_config_file, None, None, overwrite)
222+
config.load_merge_configs(private_config_file, None, None, None, overwrite)
195223

196224
stub.assert_called_once_with(streams_dir)
197225

@@ -200,7 +228,7 @@ def test_load_multiple_overwrites(private_config_file):
200228
overwrites = [{"foo": 1, "bar": 1, "baz": 1}, {"foo": 2, "bar": 2}, {"foo": 3}]
201229

202230
expected = {"foo": 3, "bar": 2, "baz": 1}
203-
cf = config.load_merge_configs(private_config_file, None, None, *overwrites)
231+
cf = config.load_merge_configs(private_config_file, None, None, None, *overwrites)
204232

205233
assert contains(cf, expected)
206234

0 commit comments

Comments
 (0)