Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/spinup_evaluation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def apply_metrics_restart(data: xr.Dataset, mask: xr.Dataset) -> Dict[str, Any]:
),
# Uses the temperature field, consistent with the metric's name.
"temperature_DWbox_metric": lambda d: temperature_DWbox_metric(
d["velocity_u"][0], mask
d["temperature"][0], mask
),
"ACC_Drake_metric": lambda d: ACC_Drake_metric(d["velocity_u"][0], mask),
"ACC_Drake_metric_2": lambda d: ACC_Drake_metric_2(
Expand Down Expand Up @@ -187,7 +187,7 @@ def apply_metrics_output(
),
# Uses the temperature field, consistent with the metric's name:
"temperature_DWbox_metric": lambda: temperature_DWbox_metric(
grid_output["velocity_u"], mask
grid_output["temperature"], mask
),
"ACC_Drake_metric": lambda: ACC_Drake_metric(grid_output["velocity_u"], mask),
"ACC_Drake_metric_2": lambda: ACC_Drake_metric_2(
Expand Down
26 changes: 24 additions & 2 deletions src/spinup_evaluation/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import glob
import os
from pathlib import Path
from typing import Dict, Mapping, Optional, Union

import xarray as xr
Expand Down Expand Up @@ -120,8 +121,26 @@ def _normalise_var_specs(
return normalised


def load_mesh_mask(path: str) -> xr.Dataset:
def resolve_mesh_mask(mesh_mask: str, sim_path: str) -> Path:
    """Return the location of the mesh mask file as a ``Path``.

    An absolute ``mesh_mask`` is used verbatim; a relative one is looked up
    underneath ``sim_path``.

    Raises:
        FileNotFoundError: if the resolved file does not exist, with a hint
            on how to fix the configuration.
    """
    raw = Path(mesh_mask)
    resolved = raw if raw.is_absolute() else Path(sim_path) / mesh_mask
    if resolved.exists():
        return resolved

    # Fail early with actionable advice rather than letting xarray raise later.
    hint = (
        "Set `mesh_mask` to an absolute path in your YAML, "
        "or ensure it exists under --sim-path."
    )
    msg = f"Mesh mask file not found: {resolved}. {hint}"
    raise FileNotFoundError(msg)


def load_mesh_mask(path: Path) -> xr.Dataset:
"""Load the NEMO mesh mask file and validate required fields."""
if not path.exists():
msg = f"Mesh mask file not found: {path}"
raise FileNotFoundError(msg)
ds = xr.open_dataset(path)
required_vars = ["tmask", "e1t", "e2t", "e3t_0"]
missing = [v for v in required_vars if v not in ds.variables]
Expand Down Expand Up @@ -237,7 +256,10 @@ def load_dino_data(
if "mesh_mask" not in setup:
msg = "setup must specify 'mesh_mask'."
raise ValueError(msg)
mesh_mask_path = os.path.join(base, str(setup["mesh_mask"]))

# Resolve mesh mask path
mesh_mask_path = resolve_mesh_mask(str(setup["mesh_mask"]), base)

data["mesh_mask"] = load_mesh_mask(mesh_mask_path)

# restart (optional / controlled by mode)
Expand Down
49 changes: 25 additions & 24 deletions src/spinup_evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def check_density(density: xarray.DataArray, epsilon: float = 1e-5):
"""
density = density.where(density != 0)
diff = density - density.shift(depth=-1)
bad_prop = (diff > epsilon).mean(dim=["depth", "nav_lat", "nav_lon"])
bad_prop = (diff > epsilon).mean(dim=["depth", "y", "x"])
return bad_prop


Expand Down Expand Up @@ -81,9 +81,9 @@ def temperature_500m_30NS_metric(
)

# Returning Average Temperature at 500m depth as a numpy scalar
return (t500_30NS * area_500m_30NS).sum(
dim=["nav_lat", "nav_lon"]
) / area_500m_30NS.sum(dim=["nav_lat", "nav_lon"])
return (t500_30NS * area_500m_30NS).sum(dim=["y", "x"]) / area_500m_30NS.sum(
dim=["y", "x"]
)


def temperature_BWbox_metric(temperature: xarray.DataArray, file_mask: xarray.Dataset):
Expand Down Expand Up @@ -140,10 +140,9 @@ def temperature_BWbox_metric(temperature: xarray.DataArray, file_mask: xarray.Da
* (abs(temperature.nav_lat) < LAT_BOUND)
)
)

# Returning Average Temperature on Box
return (t_BW * area_BW).sum(dim=["nav_lat", "nav_lon", "depth"]) / area_BW.sum(
dim=["nav_lat", "nav_lon", "depth"]
return (t_BW * area_BW).sum(dim=["y", "x", "depth"]) / area_BW.sum(
dim=["y", "x", "depth"]
)


Expand Down Expand Up @@ -203,9 +202,10 @@ def temperature_DWbox_metric(temperature: xarray.DataArray, file_mask: xarray.Da
)
)

breakpoint()
# Returning Average Temperature on Box
return (t_DW * area_DW).sum(dim=["nav_lat", "nav_lon", "depth"]) / area_DW.sum(
dim=["nav_lat", "nav_lon", "depth"]
return (t_DW * area_DW).sum(dim=["y", "x", "depth"]) / area_DW.sum(
dim=["y", "x", "depth"]
)


Expand Down Expand Up @@ -235,20 +235,20 @@ def ACC_Drake_metric(uo, file_mask):
1D DataArray (over time_counter), representing the total transport across Drake
Passage (in Sv). Length is 1 if only a single time step is present.
"""
umask_Drake = file_mask.umask.isel(nav_lon=0).squeeze()
umask_Drake = file_mask.umask.isel(x=0).squeeze()
e3u = file_mask.e3u_0.squeeze()
e2u = file_mask.e2u.squeeze()

# Masking the variables onto the Drake Passage

u_masked = uo.isel(nav_lon=0) * umask_Drake
e3u_masked = e3u.isel(nav_lon=0) * umask_Drake
e2u_masked = e2u.isel(nav_lon=0) * umask_Drake
u_masked = uo.isel(x=0) * umask_Drake
e3u_masked = e3u.isel(x=0) * umask_Drake
e2u_masked = e2u.isel(x=0) * umask_Drake

# Multiplying zonal velocity by the sectional areas (e2u*e3u)

ubar = u_masked * e3u_masked
flux = (e2u_masked * ubar).sum(dim=["nav_lat", "depth"])
flux = (e2u_masked * ubar).sum(dim=["y", "depth"])
# Returning Total Transport across Drake passage as a numpy scalar (unit : Sv)
return flux / 1e6

Expand Down Expand Up @@ -283,22 +283,22 @@ def ACC_Drake_metric_2(
1D DataArray (over time_counter), representing the total transport across Drake
Passage (in Sv). Length is 1 if only a single time step is present.
"""
umask_Drake = file_mask.umask.isel(nav_lon=0).squeeze()
umask_Drake = file_mask.umask.isel(x=0).squeeze()
e3u_0 = file_mask.e3u_0.squeeze()
e2u = file_mask.e2u.squeeze()

# Recomputing e3u, using ssh to refactor the original e3u_0 cell heights)
ssh_u = (ssh + ssh.roll(nav_lon=-1)) / 2
ssh_u = (ssh + ssh.roll(x=-1)) / 2
bathy_u = e3u_0.sum(dim="depth")
ssumask = umask_Drake[:, 0]
e3u = e3u_0 * (1 + ssh_u * ssumask / (bathy_u + 1 - ssumask))
# Masking the variables onto the Drake Passage
u_masked = uo.isel(nav_lon=0) * umask_Drake
e3u_masked = e3u.isel(nav_lon=0) * umask_Drake
e2u_masked = e2u.isel(nav_lon=0) * umask_Drake
u_masked = uo.isel(x=0) * umask_Drake
e3u_masked = e3u.isel(x=0) * umask_Drake
e2u_masked = e2u.isel(x=0) * umask_Drake
# Multiplying zonal velocity by the sectional areas (e2u*e3u)
ubar = u_masked * e3u_masked
flux = (e2u_masked * ubar).sum(dim=["nav_lat", "depth"])
flux = (e2u_masked * ubar).sum(dim=["y", "depth"])
# Return Total Transport across Drake passage as a numpy scalar (unit : Sv)
return flux / 1e6

Expand Down Expand Up @@ -335,7 +335,7 @@ def NASTG_BSF_max(
e1v = file_mask.e1v.squeeze()
vmask = file_mask.vmask.squeeze()
# Updating e3v from e3v_0 and SSH
ssh_v = (ssh + ssh.roll(nav_lat=-1)) / 2
ssh_v = (ssh + ssh.roll(y=-1)) / 2
bathy_v = e3v_0.sum(dim="depth")
ssvmask = vmask.isel(depth=0)
e3v = e3v_0 * (1 + ssh_v * ssvmask / (bathy_v + 1 - ssvmask))
Expand All @@ -346,12 +346,13 @@ def NASTG_BSF_max(
# (BSF=0)
V = (vo * e3v).sum(dim="depth") # == "Barotropic Velocity" * Bathymetry
BSF = (V * e1v * ssvmask).cumsum(
dim="nav_lon"
dim="x"
) / 1e6 # Integrating from the West, and converting from m³/s to Sv
# Selecting 0N-40N window where to search for the maximum, which will correspond to
# the center of rotation for the gyre
BSF_NASPG = BSF.where(abs(BSF.nav_lat - 20) < 20) # noqa: PLR2004
breakpoint()
BSF_NASPG = BSF.where(abs(vo.nav_lat - 20) < 20) # noqa: PLR2004

# Selecting the maximum value of the BSF in the selected window
# and return it as a numpy scalar
return BSF_NASPG.max(dim=["nav_lat", "nav_lon"])
return BSF_NASPG.max(dim=["y", "x"])
21 changes: 15 additions & 6 deletions src/spinup_evaluation/standardise_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
"depthu",
"depthv",
], # Depth can be 'depth', 'nav_lev', or 'deptht'
"latitude": ["nav_lat", "y"], # Latitude can be 'nav_lat' or 'y'
"longitude": ["nav_lon", "x"], # Longitude can be 'nav_lon' or 'x'
# "latitude": ["nav_lat", "y"], # Latitude can be 'nav_lat' or 'y'
# "longitude": ["nav_lon", "x"], # Longitude can be 'nav_lon' or 'x'
"ssh": ["sshn", "ssh"], # Sea surface height could be 'sshn' or 'ssh'
"time_counter": [
"time_counter",
Expand Down Expand Up @@ -72,14 +72,23 @@ def standardise(dataset, variable_dict):
break

ds = ds.rename(rename_map)
breakpoint()

if "x" in ds.dims:
ds = ds.rename({"x": "nav_lon"})
if "y" in ds.dims:
ds = ds.rename({"y": "nav_lat"})
# if "x" in ds.dims:
# ds = ds.rename({"x": "nav_lon"})
# if "y" in ds.dims:
# ds = ds.rename({"y": "nav_lat"})

# Promote nav_lat and nav_lon to coordinates if they are not already
# Error is exhibited in DINO restart file
for name in ("nav_lat", "nav_lon"):
if name in ds and name not in ds.coords:
ds = ds.set_coords(name) # zero-copy promotion to coordinate

if is_da:
new_name = rename_map.get(orig_name, orig_name)
return ds[new_name]

# breakpoint()

return ds
Loading