diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py
index ba1678ed3..26d1709b2 100644
--- a/packages/common/src/weathergen/common/config.py
+++ b/packages/common/src/weathergen/common/config.py
@@ -287,6 +287,7 @@ def load_merge_configs(
     private_home: Path | None = None,
     from_run_id: str | None = None,
     mini_epoch: int | None = None,
+    base: Path | Config | None = None,
     *overwrites: Path | dict | Config,
 ) -> Config:
     """
@@ -298,6 +299,8 @@ def load_merge_configs(
         from_run_id: Run id of the pretrained WeatherGenerator model
             to continue training or inference
         mini_epoch: Mini_epoch of the checkpoint to load. -1 indicates last checkpoint available.
+        base: Path to the base configuration file or an already loaded base config. Uses the
+            default configuration if None. Only used if from_run_id is None.
         *overwrites: Additional overwrites from different sources
 
     Note: The order of precedence for merging the final config is in ascending order:
@@ -328,7 +331,7 @@ def load_merge_configs(
     private_config = set_paths(private_config)
 
     if from_run_id is None:
-        base_config = _load_default_conf()
+        base_config = _load_base_conf(base)
     else:
         base_config = load_run_config(
             from_run_id, mini_epoch, private_config.get("model_path", None)
@@ -496,11 +499,38 @@ def _load_private_conf(private_home: Path | None = None) -> DictConfig:
     return private_cf
 
 
-def _load_default_conf() -> Config:
-    """Deserialize default configuration."""
-    c = OmegaConf.load(_DEFAULT_CONFIG_PTH)
-    assert isinstance(c, Config)
-    return c
+def _load_base_conf(base: Path | Config | None) -> Config:
+    """
+    Return the base configuration.
+
+    Args:
+        base: Path of a config file to deserialize, an already loaded config to use
+            as is, or None to deserialize the default configuration.
+
+    Returns:
+        The base configuration.
+
+    Raises:
+        TypeError: If base is neither a Path, a Config, nor None.
+    """
+    match base:
+        case None:
+            _logger.info("Deserialize default configuration.")
+            conf = OmegaConf.load(_DEFAULT_CONFIG_PTH)
+        case Path():
+            _logger.info(f"Loading specified base config from file: {base}.")
+            conf = OmegaConf.load(base)
+        case Config():
+            _logger.info(f"Using existing config as base: {base}.")
+            conf = base
+        case _:
+            # Fail loudly instead of silently falling back to the default config,
+            # e.g. when an overwrite is passed positionally in place of `base`.
+            raise TypeError(
+                f"base must be a Path, a Config or None, got {type(base).__name__}."
+            )
+    assert isinstance(conf, Config)
+    return conf
 
 
 def load_streams(streams_directory: Path) -> list[Config]:
diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py
index b7c2ef5f9..ff2b5a518 100644
--- a/src/weathergen/run_train.py
+++ b/src/weathergen/run_train.py
@@ -55,6 +55,7 @@ def inference_from_args(argl: list[str]):
         args.private_config,
         args.from_run_id,
         args.mini_epoch,
+        args.base_config,
         *args.config,
         inference_overwrite,
         cli_overwrite,
@@ -106,6 +107,7 @@ def train_continue_from_args(argl: list[str]):
         args.from_run_id,
         args.mini_epoch,
+        args.base_config,
         finetune_overwrite,
         *args.config,
         cli_overwrite,
     )
@@ -154,7 +156,9 @@ def train_with_args(argl: list[str], stream_dir: str | None):
 
     cli_overwrite = config.from_cli_arglist(args.options)
 
-    cf = config.load_merge_configs(args.private_config, None, None, *args.config, cli_overwrite)
+    cf = config.load_merge_configs(
+        args.private_config, None, None, args.base_config, *args.config, cli_overwrite
+    )
     cf = config.set_run_id(cf, args.run_id, False)
 
     cf.data_loader_rng_seed = int(time.time())
diff --git a/src/weathergen/utils/cli.py b/src/weathergen/utils/cli.py
index 8f6f7eed1..fea1760f4 100644
--- a/src/weathergen/utils/cli.py
+++ b/src/weathergen/utils/cli.py
@@ -113,6 +113,14 @@ def _add_general_arguments(parser: argparse.ArgumentParser):
             " Individual items should be of the form: parent_obj.nested_obj=value"
         ),
     )
+    parser.add_argument(
+        "--base-config",
+        type=Path,
+        help=(
+            "Path to the base configuration file. "
+            "If not provided, ./config/default_config.yml is used."
+        ),
+    )
 
 
 def _add_model_loading_params(parser: argparse.ArgumentParser):
diff --git a/tests/test_config.py b/tests/test_config.py
index e5ba9cbf4..c5fa2e5f3 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -62,6 +62,10 @@
     (pathlib.Path("#test.yml"), DUMMY_STREAM_CONF),
 ]
 
+DUMMY_BASE_CONF = {
+    "foo": "bar"
+}
+
 
 def contains_keys(super_config, sub_config):
     keys_present = [key in super_config.keys() for key in sub_config.keys()]
@@ -152,6 +156,17 @@ def config_fresh(private_config_file):
     return cf
 
 
+@pytest.fixture
+def base_config():
+    return OmegaConf.create(DUMMY_BASE_CONF)
+
+@pytest.fixture
+def base_file(base_config):
+    with tempfile.NamedTemporaryFile("w+") as temp:
+        temp.write(OmegaConf.to_yaml(base_config))
+        temp.flush()
+        yield pathlib.Path(temp.name)
+
 def test_contains_private(config_fresh):
     sanitized_private_conf = DUMMY_PRIVATE_CONF.copy()
 
@@ -167,14 +182,14 @@ def test_is_paths_set(config_fresh):
 
 @pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
 def test_load_with_overwrite_dict(overwrite_dict, private_config_file):
-    cf = config.load_merge_configs(private_config_file, None, None, overwrite_dict)
+    cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_dict)
 
     assert contains(cf, overwrite_dict)
 
 
 @pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
 def test_load_with_overwrite_config(overwrite_config, private_config_file):
-    cf = config.load_merge_configs(private_config_file, None, None, overwrite_config)
+    cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_config)
 
     assert contains(cf, overwrite_config)
 
@@ -182,7 +197,20 @@ def test_load_with_overwrite_config(overwrite_config, private_config_file):
 @pytest.mark.parametrize("overwrite_dict", DUMMY_OVERWRITES, indirect=True)
 def test_load_with_overwrite_file(private_config_file, overwrite_file):
     sub_cf = OmegaConf.load(overwrite_file)
-    cf = config.load_merge_configs(private_config_file, None, None, overwrite_file)
+    cf = config.load_merge_configs(private_config_file, None, None, None, overwrite_file)
+
+    assert contains(cf, sub_cf)
+
+
+def test_load_with_base_config(private_config_file, base_config):
+    cf = config.load_merge_configs(private_config_file, None, None, base_config)
+
+    assert contains(cf, base_config)
+
+
+def test_load_with_base_file(private_config_file, base_file):
+    sub_cf = OmegaConf.load(base_file)
+    cf = config.load_merge_configs(private_config_file, None, None, base_file)
 
     assert contains(cf, sub_cf)
 
@@ -191,7 +219,7 @@ def test_load_with_stream_in_overwrite(private_config_file, streams_dir, mocker)
     overwrite = {"streams_directory": streams_dir}
     stub = mocker.patch("weathergen.common.config.load_streams", return_value=streams_dir)
 
-    config.load_merge_configs(private_config_file, None, None, overwrite)
+    config.load_merge_configs(private_config_file, None, None, None, overwrite)
 
     stub.assert_called_once_with(streams_dir)
 
@@ -200,7 +228,7 @@ def test_load_multiple_overwrites(private_config_file):
     overwrites = [{"foo": 1, "bar": 1, "baz": 1}, {"foo": 2, "bar": 2}, {"foo": 3}]
     expected = {"foo": 3, "bar": 2, "baz": 1}
 
-    cf = config.load_merge_configs(private_config_file, None, None, *overwrites)
+    cf = config.load_merge_configs(private_config_file, None, None, None, *overwrites)
 
     assert contains(cf, expected)
 