diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index ba1678ed3..3069d1f73 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -240,15 +240,9 @@ def get_model_results(run_id: str, mini_epoch: int, rank: int) -> Path: Get the path to the model results zarr store from a given run_id and mini_epoch. """ run_results = Path(_load_private_conf(None)["path_shared_working_dir"]) / f"results/{run_id}" + zarr_path = run_results / f"validation_chkpt{mini_epoch:05d}_rank{rank:04d}.zarr" - zarr_path_new = run_results / f"validation_chkpt{mini_epoch:05d}_rank{rank:04d}.zarr" - zarr_path_old = run_results / f"validation_epoch{mini_epoch:05d}_rank{rank:04d}.zarr" - - if zarr_path_new.exists() or zarr_path_new.is_dir(): - zarr_path = zarr_path_new - elif zarr_path_old.exists() or zarr_path_old.is_dir(): - zarr_path = zarr_path_old - else: + if not (zarr_path.exists() or zarr_path.is_dir()): raise FileNotFoundError( f"Zarr file with run_id {run_id}, mini_epoch {mini_epoch} and rank {rank} does not " f"exist or is not a directory." @@ -341,11 +335,6 @@ def load_merge_configs( assert isinstance(c, Config) c = _add_interpolation(c) - # Ensure the config has mini-epoch notation - if hasattr(c, "samples_per_epoch"): - c.samples_per_mini_epoch = c.samples_per_epoch - c.num_mini_epochs = c.num_epochs - return c diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 03e4952ea..b10bb8a9c 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -71,17 +71,9 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation") ) - fname_zarr_new = self.results_dir.joinpath( + self.fname_zarr = self.results_dir.joinpath( f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" ) - fname_zarr_old = self.results_dir.joinpath( - f"validation_epoch{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" - ) - - if fname_zarr_new.exists() or fname_zarr_new.is_dir(): - self.fname_zarr = fname_zarr_new - else: - self.fname_zarr = fname_zarr_old if not self.fname_zarr.exists() or not self.fname_zarr.is_dir(): _logger.error(f"Zarr file {self.fname_zarr} does not exist.") diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 5649b8000..e106579c4 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -172,10 +172,6 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): ) filename = f"{run_id}_{mini_epoch_id}.chkpt" - if not (path_run / filename).exists(): - mini_epoch_id = f"epoch{mini_epoch:05d}" - filename = f"{run_id}_{mini_epoch_id}.chkpt" - params = torch.load( path_run / filename, map_location=torch.device("cpu"), mmap=True, weights_only=True )