Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions packages/common/src/weathergen/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def load_merge_configs(
private_home: Path | None = None,
from_run_id: str | None = None,
mini_epoch: int | None = None,
base: Path | Config | None = None,
*overwrites: Path | dict | Config,
) -> Config:
"""
Expand All @@ -298,6 +299,7 @@ def load_merge_configs(
from_run_id: Run id of the pretrained WeatherGenerator model
to continue training or inference
mini_epoch: Mini_epoch of the checkpoint to load. -1 indicates last checkpoint available.
base: Path to the base configuration file. Uses default configuration if None.
*overwrites: Additional overwrites from different sources

Note: The order of precedence for merging the final config is in ascending order:
Expand Down Expand Up @@ -328,7 +330,7 @@ def load_merge_configs(
private_config = set_paths(private_config)

if from_run_id is None:
base_config = _load_default_conf()
base_config = _load_base_conf(base)
else:
base_config = load_run_config(
from_run_id, mini_epoch, private_config.get("model_path", None)
Expand Down Expand Up @@ -496,11 +498,20 @@ def _load_private_conf(private_home: Path | None = None) -> DictConfig:
return private_cf


def _load_default_conf() -> Config:
"""Deserialize default configuration."""
c = OmegaConf.load(_DEFAULT_CONFIG_PTH)
assert isinstance(c, Config)
return c
def _load_base_conf(base: Path | Config | None) -> Config:
"""Return the base configuration"""
match base :
case Path():
_logger.info(f"Loading specified base config from file: {base}.")
conf = OmegaConf.load(base)
case Config():
_logger.info(f"Using existing config as base: {base}.")
conf = base
case _:
_logger.info("Deserialize default configuration.")
conf = OmegaConf.load(_DEFAULT_CONFIG_PTH)
assert isinstance(conf, Config)
return conf


def load_streams(streams_directory: Path) -> list[Config]:
Expand Down
4 changes: 3 additions & 1 deletion src/weathergen/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def inference_from_args(argl: list[str]):
args.private_config,
args.from_run_id,
args.mini_epoch,
args.base_config,
*args.config,
inference_overwrite,
cli_overwrite,
Expand Down Expand Up @@ -106,6 +107,7 @@ def train_continue_from_args(argl: list[str]):
args.from_run_id,
args.mini_epoch,
finetune_overwrite,
args.base_config,
*args.config,
cli_overwrite,
)
Expand Down Expand Up @@ -154,7 +156,7 @@ def train_with_args(argl: list[str], stream_dir: str | None):

cli_overwrite = config.from_cli_arglist(args.options)

cf = config.load_merge_configs(args.private_config, None, None, *args.config, cli_overwrite)
cf = config.load_merge_configs(args.private_config, None, None, args.base_config, *args.config, cli_overwrite)
cf = config.set_run_id(cf, args.run_id, False)

cf.data_loader_rng_seed = int(time.time())
Expand Down
8 changes: 8 additions & 0 deletions src/weathergen/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ def _add_general_arguments(parser: argparse.ArgumentParser):
" Individual items should be of the form: parent_obj.nested_obj=value"
),
)
parser.add_argument(
"--base-config",
type=Path,
help=(
"Path to the base configuration file."
"If not provided, ./config/default_config.yml is used."
)
)


def _add_model_loading_params(parser: argparse.ArgumentParser):
Expand Down
38 changes: 33 additions & 5 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
(pathlib.Path("#test.yml"), DUMMY_STREAM_CONF),
]

DUMMY_BASE_CONF = {
"foo": "bar"
}
Comment on lines +65 to +67
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can contain just one dummy key "foo": "bar" to distinguish if the test correctly loads the new base config instead of the default config.



def contains_keys(super_config, sub_config):
keys_present = [key in super_config.keys() for key in sub_config.keys()]
Expand Down Expand Up @@ -152,6 +156,17 @@ def config_fresh(private_config_file):

return cf

@pytest.fixture
def base_config():
return OmegaConf.create(DUMMY_BASE_CONF)

@pytest.fixure
def base_file(base_config):
with tempfile.NamedTemporaryFile("w+") as temp:
temp.write(OmegaConf.to_yaml(base_config))
temp.flush()
yield pathlib.Path(temp.name)


def test_contains_private(config_fresh):
sanitized_private_conf = DUMMY_PRIVATE_CONF.copy()
Expand All @@ -167,22 +182,35 @@ def test_is_paths_set(config_fresh):

@pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
def test_load_with_overwrite_dict(overwrite_dict, private_config_file):
cf = config.load_merge_configs(private_config_file, None, None, overwrite_dict)
cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_dict)

assert contains(cf, overwrite_dict)


@pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
def test_load_with_overwrite_config(overwrite_config, private_config_file):
cf = config.load_merge_configs(private_config_file, None, None, overwrite_config)
cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_config)

assert contains(cf, overwrite_config)


@pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
def test_load_with_overwrite_file(private_config_file, overwrite_file):
sub_cf = OmegaConf.load(overwrite_file)
cf = config.load_merge_configs(private_config_file, None, None, overwrite_file)
cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_file)

assert contains(cf, sub_cf)


def test_load_with_base_config(private_config_file, base_config):
cf = config.load_merge_configs(private_config_file, None, None, base_config)

assert contains(cf, base_config)


def test_load_with_base_file(private_config_file, base_file):
sub_cf = OmegaConf.load(base_file)
cf = config.load_merge_configs(private_config_file, None, None, base_file)

assert contains(cf, sub_cf)

Expand All @@ -191,7 +219,7 @@ def test_load_with_stream_in_overwrite(private_config_file, streams_dir, mocker)
overwrite = {"streams_directory": streams_dir}
stub = mocker.patch("weathergen.common.config.load_streams", return_value=streams_dir)

config.load_merge_configs(private_config_file, None, None, overwrite)
config.load_merge_configs(private_config_file, None, None, None, overwrite)

stub.assert_called_once_with(streams_dir)

Expand All @@ -200,7 +228,7 @@ def test_load_multiple_overwrites(private_config_file):
overwrites = [{"foo": 1, "bar": 1, "baz": 1}, {"foo": 2, "bar": 2}, {"foo": 3}]

expected = {"foo": 3, "bar": 2, "baz": 1}
cf = config.load_merge_configs(private_config_file, None, None, *overwrites)
cf = config.load_merge_configs(private_config_file, None, None, None, *overwrites)

assert contains(cf, expected)

Expand Down
Loading