diff --git a/src/spinup_evaluation/cli.py b/src/spinup_evaluation/cli.py index d430124..203c201 100644 --- a/src/spinup_evaluation/cli.py +++ b/src/spinup_evaluation/cli.py @@ -145,7 +145,7 @@ def apply_metrics_restart(data: xr.Dataset, mask: xr.Dataset) -> Dict[str, Any]: ), # Name suggests temperature, but kept as-is from original code. "temperature_DWbox_metric": lambda d: temperature_DWbox_metric( - d["velocity_u"][0], mask + d["temperature"][0], mask ), "ACC_Drake_metric": lambda d: ACC_Drake_metric(d["velocity_u"][0], mask), "ACC_Drake_metric_2": lambda d: ACC_Drake_metric_2( @@ -187,7 +187,7 @@ def apply_metrics_output( ), # Name suggests temperature, but kept aligned with original mapping: "temperature_DWbox_metric": lambda: temperature_DWbox_metric( - grid_output["velocity_u"], mask + grid_output["temperature"], mask ), "ACC_Drake_metric": lambda: ACC_Drake_metric(grid_output["velocity_u"], mask), "ACC_Drake_metric_2": lambda: ACC_Drake_metric_2( diff --git a/src/spinup_evaluation/loader.py b/src/spinup_evaluation/loader.py index 1fa6a24..511f168 100644 --- a/src/spinup_evaluation/loader.py +++ b/src/spinup_evaluation/loader.py @@ -2,6 +2,7 @@ import glob import os +from pathlib import Path from typing import Dict, Mapping, Optional, Union import xarray as xr @@ -120,8 +121,26 @@ def _normalise_var_specs( return normalised -def load_mesh_mask(path: str) -> xr.Dataset: +def resolve_mesh_mask(mesh_mask: str, sim_path: str) -> Path: + """Resolve the mesh mask path, handling absolute and relative paths.""" + p = Path(mesh_mask) + candidate = p if p.is_absolute() else Path(sim_path) / mesh_mask + if not candidate.exists(): + hint = ( + "Set `mesh_mask` to an absolute path in your YAML, " + "or ensure it exists under --sim-path." + ) + msg = f"Mesh mask file not found: {candidate}. {hint}" + raise FileNotFoundError(msg) + + return candidate + + +def load_mesh_mask(path: Path) -> xr.Dataset: """Load the NEMO mesh mask file and validate required fields.""" + if not path.exists(): + msg = f"Mesh mask file not found: {path}" + raise FileNotFoundError(msg) ds = xr.open_dataset(path) required_vars = ["tmask", "e1t", "e2t", "e3t_0"] missing = [v for v in required_vars if v not in ds.variables] @@ -237,7 +256,10 @@ def load_dino_data( if "mesh_mask" not in setup: msg = "setup must specify 'mesh_mask'." raise ValueError(msg) - mesh_mask_path = os.path.join(base, str(setup["mesh_mask"])) + + # Resolve mesh mask path + mesh_mask_path = resolve_mesh_mask(str(setup["mesh_mask"]), base) + data["mesh_mask"] = load_mesh_mask(mesh_mask_path) # restart (optional / controlled by mode) diff --git a/src/spinup_evaluation/metrics.py b/src/spinup_evaluation/metrics.py index 14534fc..45c234f 100644 --- a/src/spinup_evaluation/metrics.py +++ b/src/spinup_evaluation/metrics.py @@ -31,7 +31,7 @@ def check_density(density: xarray.DataArray, epsilon: float = 1e-5): """ density = density.where(density != 0) diff = density - density.shift(depth=-1) - bad_prop = (diff > epsilon).mean(dim=["depth", "nav_lat", "nav_lon"]) + bad_prop = (diff > epsilon).mean(dim=["depth", "y", "x"]) return bad_prop @@ -81,9 +81,9 @@ def temperature_500m_30NS_metric( ) # Returning Average Temperature at 500m depth as a numpy scalar - return (t500_30NS * area_500m_30NS).sum( - dim=["nav_lat", "nav_lon"] - ) / area_500m_30NS.sum(dim=["nav_lat", "nav_lon"]) + return (t500_30NS * area_500m_30NS).sum(dim=["y", "x"]) / area_500m_30NS.sum( + dim=["y", "x"] + ) def temperature_BWbox_metric(temperature: xarray.DataArray, file_mask: xarray.Dataset): @@ -140,10 +140,9 @@ def temperature_BWbox_metric(temperature: xarray.DataArray, file_mask: xarray.Da * (abs(temperature.nav_lat) < LAT_BOUND) ) ) - # Returning Average Temperature on Box - return (t_BW * area_BW).sum(dim=["nav_lat", "nav_lon", "depth"]) / area_BW.sum( - dim=["nav_lat", "nav_lon", "depth"] + return (t_BW * area_BW).sum(dim=["y", "x", "depth"]) / area_BW.sum( + dim=["y", "x", "depth"] ) @@ -203,9 +202,10 @@ def temperature_DWbox_metric(temperature: xarray.DataArray, file_mask: xarray.Da ) ) + breakpoint() # Returning Average Temperature on Box - return (t_DW * area_DW).sum(dim=["nav_lat", "nav_lon", "depth"]) / area_DW.sum( - dim=["nav_lat", "nav_lon", "depth"] + return (t_DW * area_DW).sum(dim=["y", "x", "depth"]) / area_DW.sum( + dim=["y", "x", "depth"] ) @@ -235,20 +235,20 @@ def ACC_Drake_metric(uo, file_mask): 1D DataArray (over time_counter), representing the total transport across Drake Passage (in Sv). Length is 1 if only a single time step is present. """ - umask_Drake = file_mask.umask.isel(nav_lon=0).squeeze() + umask_Drake = file_mask.umask.isel(x=0).squeeze() e3u = file_mask.e3u_0.squeeze() e2u = file_mask.e2u.squeeze() # Masking the variables onto the Drake Passage - u_masked = uo.isel(nav_lon=0) * umask_Drake - e3u_masked = e3u.isel(nav_lon=0) * umask_Drake - e2u_masked = e2u.isel(nav_lon=0) * umask_Drake + u_masked = uo.isel(x=0) * umask_Drake + e3u_masked = e3u.isel(x=0) * umask_Drake + e2u_masked = e2u.isel(x=0) * umask_Drake # Multiplying zonal velocity by the sectional areas (e2u*e3u) ubar = u_masked * e3u_masked - flux = (e2u_masked * ubar).sum(dim=["nav_lat", "depth"]) + flux = (e2u_masked * ubar).sum(dim=["y", "depth"]) # Returning Total Transport across Drake passage as a numpy scalar (unit : Sv) return flux / 1e6 @@ -283,22 +283,22 @@ def ACC_Drake_metric_2( 1D DataArray (over time_counter), representing the total transport across Drake Passage (in Sv). Length is 1 if only a single time step is present. """ - umask_Drake = file_mask.umask.isel(nav_lon=0).squeeze() + umask_Drake = file_mask.umask.isel(x=0).squeeze() e3u_0 = file_mask.e3u_0.squeeze() e2u = file_mask.e2u.squeeze() # Recomputing e3u, using ssh to refactor the original e3u_0 cell heights) - ssh_u = (ssh + ssh.roll(nav_lon=-1)) / 2 + ssh_u = (ssh + ssh.roll(x=-1)) / 2 bathy_u = e3u_0.sum(dim="depth") ssumask = umask_Drake[:, 0] e3u = e3u_0 * (1 + ssh_u * ssumask / (bathy_u + 1 - ssumask)) # Masking the variables onto the Drake Passage - u_masked = uo.isel(nav_lon=0) * umask_Drake - e3u_masked = e3u.isel(nav_lon=0) * umask_Drake - e2u_masked = e2u.isel(nav_lon=0) * umask_Drake + u_masked = uo.isel(x=0) * umask_Drake + e3u_masked = e3u.isel(x=0) * umask_Drake + e2u_masked = e2u.isel(x=0) * umask_Drake # Multiplying zonal velocity by the sectional areas (e2u*e3u) ubar = u_masked * e3u_masked - flux = (e2u_masked * ubar).sum(dim=["nav_lat", "depth"]) + flux = (e2u_masked * ubar).sum(dim=["y", "depth"]) # Return Total Transport across Drake passage as a numpy scalar (unit : Sv) return flux / 1e6 @@ -335,7 +335,7 @@ def NASTG_BSF_max( e1v = file_mask.e1v.squeeze() vmask = file_mask.vmask.squeeze() # Updating e3v from e3v_0 and SSH - ssh_v = (ssh + ssh.roll(nav_lat=-1)) / 2 + ssh_v = (ssh + ssh.roll(y=-1)) / 2 bathy_v = e3v_0.sum(dim="depth") ssvmask = vmask.isel(depth=0) e3v = e3v_0 * (1 + ssh_v * ssvmask / (bathy_v + 1 - ssvmask)) @@ -346,12 +346,13 @@ def NASTG_BSF_max( # (BSF=0) V = (vo * e3v).sum(dim="depth") # == "Barotropic Velocity" * Bathymetry BSF = (V * e1v * ssvmask).cumsum( - dim="nav_lon" + dim="x" ) / 1e6 # Integrating from the West, and converting from m³/s to Sv # Selecting 0N-40N window where to search for the maximum, which will correspond to # the center of rotation for the gyre - BSF_NASPG = BSF.where(abs(BSF.nav_lat - 20) < 20) # noqa: PLR2004 + breakpoint() + BSF_NASPG = BSF.where(abs(vo.nav_lat - 20) < 20) # noqa: PLR2004 # Selecting the maximum value of the BSF in the selected window # and return it as a numpy scalar - return BSF_NASPG.max(dim=["nav_lat", "nav_lon"]) + return BSF_NASPG.max(dim=["y", "x"]) diff --git a/src/spinup_evaluation/standardise_inputs.py b/src/spinup_evaluation/standardise_inputs.py index bdf5961..6d581be 100644 --- a/src/spinup_evaluation/standardise_inputs.py +++ b/src/spinup_evaluation/standardise_inputs.py @@ -29,8 +29,8 @@ "depthu", "depthv", ], # Depth can be 'depth', 'nav_lev', or 'deptht' - "latitude": ["nav_lat", "y"], # Latitude can be 'nav_lat' or 'y' - "longitude": ["nav_lon", "x"], # Longitude can be 'nav_lon' or 'x' + # "latitude": ["nav_lat", "y"], # Latitude can be 'nav_lat' or 'y' + # "longitude": ["nav_lon", "x"], # Longitude can be 'nav_lon' or 'x' "ssh": ["sshn", "ssh"], # Sea surface height could be 'sshn' or 'ssh' "time_counter": [ "time_counter", @@ -72,14 +72,23 @@ def standardise(dataset, variable_dict): break ds = ds.rename(rename_map) + breakpoint() - if "x" in ds.dims: - ds = ds.rename({"x": "nav_lon"}) - if "y" in ds.dims: - ds = ds.rename({"y": "nav_lat"}) + # if "x" in ds.dims: + # ds = ds.rename({"x": "nav_lon"}) + # if "y" in ds.dims: + # ds = ds.rename({"y": "nav_lat"}) + + # Promote nav_lat and nav_lon to coordinates if they are not already + # Error is exhibited in DINO restart file + for name in ("nav_lat", "nav_lon"): + if name in ds and name not in ds.coords: + ds = ds.set_coords(name) # zero-copy promotion to coordinate if is_da: new_name = rename_map.get(orig_name, orig_name) return ds[new_name] + # breakpoint() + return ds