Skip to content

Commit 85d6d54

Browse files
committed
set base config (#1539)
1 parent a7ae70d commit 85d6d54

File tree

4 files changed

+59
-10
lines changed

4 files changed

+59
-10
lines changed

packages/common/src/weathergen/common/config.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ def load_merge_configs(
287287
private_home: Path | None = None,
288288
from_run_id: str | None = None,
289289
mini_epoch: int | None = None,
290+
base: Path | Config | None = None,
290291
*overwrites: Path | dict | Config,
291292
) -> Config:
292293
"""
@@ -298,6 +299,7 @@ def load_merge_configs(
298299
from_run_id: Run id of the pretrained WeatherGenerator model
299300
to continue training or inference
300301
mini_epoch: Mini_epoch of the checkpoint to load. -1 indicates last checkpoint available.
302+
base: Path to the base configuration file. Uses default configuration if None.
301303
*overwrites: Additional overwrites from different sources
302304
303305
Note: The order of precedence for merging the final config is in ascending order:
@@ -328,7 +330,7 @@ def load_merge_configs(
328330
private_config = set_paths(private_config)
329331

330332
if from_run_id is None:
331-
base_config = _load_default_conf()
333+
base_config = _load_base_conf(base)
332334
else:
333335
base_config = load_run_config(
334336
from_run_id, mini_epoch, private_config.get("model_path", None)
@@ -496,9 +498,18 @@ def _load_private_conf(private_home: Path | None = None) -> DictConfig:
496498
return private_cf
497499

498500

499-
def _load_default_conf() -> Config:
500-
"""Deserialize default configuration."""
501-
c = OmegaConf.load(_DEFAULT_CONFIG_PTH)
501+
def _load_base_conf(base: Path | Config | None) -> Config:
502+
"""Return the base configuration"""
503+
match base :
504+
case Path():
505+
_logger.info(f"Loading specified base config from file: {base}.")
506+
c = OmegaConf.load(base)
507+
case Config():
508+
_logger.info(f"Using existing config as base: {base}.")
509+
c = base
510+
case _:
511+
_logger.info("Deserialize default configuration.")
512+
c = OmegaConf.load(_DEFAULT_CONFIG_PTH)
502513
assert isinstance(c, Config)
503514
return c
504515

src/weathergen/run_train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def inference_from_args(argl: list[str]):
5555
args.private_config,
5656
args.from_run_id,
5757
args.mini_epoch,
58+
args.base_config,
5859
*args.config,
5960
inference_overwrite,
6061
cli_overwrite,
@@ -106,6 +107,7 @@ def train_continue_from_args(argl: list[str]):
106107
args.from_run_id,
107108
args.mini_epoch,
108109
finetune_overwrite,
110+
args.base_config,
109111
*args.config,
110112
cli_overwrite,
111113
)
@@ -154,7 +156,7 @@ def train_with_args(argl: list[str], stream_dir: str | None):
154156

155157
cli_overwrite = config.from_cli_arglist(args.options)
156158

157-
cf = config.load_merge_configs(args.private_config, None, None, *args.config, cli_overwrite)
159+
cf = config.load_merge_configs(args.private_config, None, None, args.base_config, *args.config, cli_overwrite)
158160
cf = config.set_run_id(cf, args.run_id, False)
159161

160162
cf.data_loader_rng_seed = int(time.time())

src/weathergen/utils/cli.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,14 @@ def _add_general_arguments(parser: argparse.ArgumentParser):
113113
" Individual items should be of the form: parent_obj.nested_obj=value"
114114
),
115115
)
116+
parser.add_argument(
117+
"--base-config",
118+
type=Path,
119+
help=(
120+
"Path to the base configuration file."
121+
"If not provided, the default configuration is used."
122+
)
123+
)
116124

117125

118126
def _add_model_loading_params(parser: argparse.ArgumentParser):

tests/test_config.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@
6262
(pathlib.Path("#test.yml"), DUMMY_STREAM_CONF),
6363
]
6464

65+
DUMMY_BASE_CONF = {
66+
# TODO add base configuration
67+
}
68+
6569

6670
def contains_keys(super_config, sub_config):
6771
keys_present = [key in super_config.keys() for key in sub_config.keys()]
@@ -152,6 +156,17 @@ def config_fresh(private_config_file):
152156

153157
return cf
154158

159+
@pytest.fixture
160+
def base_conf():
161+
return OmegaConf.create(DUMMY_BASE_CONF)
162+
163+
@pytest.fixure
164+
def base_file(base_conf):
165+
with tempfile.NamedTemporaryFile("w+") as temp:
166+
temp.write(OmegaConf.to_yaml(base_conf))
167+
temp.flush()
168+
yield pathlib.Path(temp.name)
169+
155170

156171
def test_contains_private(config_fresh):
157172
sanitized_private_conf = DUMMY_PRIVATE_CONF.copy()
@@ -167,22 +182,35 @@ def test_is_paths_set(config_fresh):
167182

168183
@pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
169184
def test_load_with_overwrite_dict(overwrite_dict, private_config_file):
170-
cf = config.load_merge_configs(private_config_file, None, None, overwrite_dict)
185+
cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_dict)
171186

172187
assert contains(cf, overwrite_dict)
173188

174189

175190
@pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
176191
def test_load_with_overwrite_config(overwrite_config, private_config_file):
177-
cf = config.load_merge_configs(private_config_file, None, None, overwrite_config)
192+
cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_config)
178193

179194
assert contains(cf, overwrite_config)
180195

181196

182197
@pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
183198
def test_load_with_overwrite_file(private_config_file, overwrite_file):
184199
sub_cf = OmegaConf.load(overwrite_file)
185-
cf = config.load_merge_configs(private_config_file, None, None, overwrite_file)
200+
cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_file)
201+
202+
assert contains(cf, sub_cf)
203+
204+
205+
def test_load_with_base_config(base_config, private_config_file):
206+
cf = config.load_merge_configs(private_config_file, None, None, base_config)
207+
208+
assert contains(cf, base_config)
209+
210+
211+
def test_load_with_base_file(base_config_file, private_config_file):
212+
sub_cf = OmegaConf.load(overwrite_file)
213+
cf = config.load_merge_configs(private_config_file, None, None, base_config_file)
186214

187215
assert contains(cf, sub_cf)
188216

@@ -191,7 +219,7 @@ def test_load_with_stream_in_overwrite(private_config_file, streams_dir, mocker)
191219
overwrite = {"streams_directory": streams_dir}
192220
stub = mocker.patch("weathergen.common.config.load_streams", return_value=streams_dir)
193221

194-
config.load_merge_configs(private_config_file, None, None, overwrite)
222+
config.load_merge_configs(private_config_file, None, None, None, overwrite)
195223

196224
stub.assert_called_once_with(streams_dir)
197225

@@ -200,7 +228,7 @@ def test_load_multiple_overwrites(private_config_file):
200228
overwrites = [{"foo": 1, "bar": 1, "baz": 1}, {"foo": 2, "bar": 2}, {"foo": 3}]
201229

202230
expected = {"foo": 3, "bar": 2, "baz": 1}
203-
cf = config.load_merge_configs(private_config_file, None, None, *overwrites)
231+
cf = config.load_merge_configs(private_config_file, None, None, None, *overwrites)
204232

205233
assert contains(cf, expected)
206234

0 commit comments

Comments
 (0)