33 changes: 27 additions & 6 deletions src/access_moppy/ocean.py
@@ -1,3 +1,4 @@
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

@@ -7,6 +8,7 @@
from access_moppy.base import CMIP6_CMORiser
from access_moppy.derivations import custom_functions, evaluate_expression
from access_moppy.ocean_supergrid import Supergrid
from access_moppy.utilities import calculate_time_bounds
from access_moppy.vocabulary_processors import CMIP6Vocabulary


@@ -60,10 +62,10 @@ def _get_dim_rename(self):
def select_and_process_variables(self):
"""Select and process variables for the CMOR output."""
input_vars = self.mapping[self.cmor_name]["model_variables"]
time_bnds = ["time_bnds"]
bnds_required = ["time_bnds"]
calc = self.mapping[self.cmor_name]["calculation"]

required_vars = set(input_vars + time_bnds)
required_vars = set(input_vars + bnds_required)
self.load_dataset(required_vars=required_vars)

dim_rename = self._get_dim_rename()
@@ -93,14 +95,33 @@ def select_and_process_variables(self):
)

self.grid_type, self.symmetric = self.infer_grid_type()
# Drop all other data variables except the CMOR variable
self.ds = self.ds[[self.cmor_name, time_bnds[0]]]

# Check and calculate time_bnds if missing
if bnds_required[0] not in self.ds:
# Warn user that bounds are missing and will be calculated automatically
warnings.warn(
f"'{bnds_required[0]}' not found in raw data. Automatically calculating bounds for '{bnds_required[0]}' coordinate.",
UserWarning,
stacklevel=2,
)
try:
calculated_bnds = calculate_time_bounds(
self.ds, time_coord="time", bnds_name="nv"
)
self.ds[bnds_required[0]] = calculated_bnds
except Exception as e:
raise ValueError(
f"time_bnds is required for CMIP6 compliance but was not found "
f"in the dataset and could not be calculated: {e}"
)

self.ds = self.ds[[self.cmor_name, bnds_required[0]]]

# Drop unused coordinates
used_coords = set()
dims = list(self.ds[self.cmor_name].dims)
if time_bnds[0] in self.ds:
dims = list(dict.fromkeys(dims + list(self.ds[time_bnds[0]].dims)))
if bnds_required[0] in self.ds:
dims = list(dict.fromkeys(dims + list(self.ds[bnds_required[0]].dims)))
for dim in dims:
if dim in self.ds.coords:
used_coords.add(dim)
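The new fallback path relies on `calculate_time_bounds` from `access_moppy.utilities`, whose implementation is not part of this diff. Below is a minimal illustrative sketch of what such a helper could do, assuming it derives contiguous bounds from the midpoints of the `time` coordinate; the real signature and algorithm in `access_moppy.utilities` may differ.

```python
import numpy as np
import xarray as xr


def sketch_calculate_time_bounds(ds, time_coord="time", bnds_name="nv"):
    """Illustrative sketch only (not the real access_moppy helper):
    build (lower, upper) bounds for each time step from the midpoints
    between consecutive time values, extrapolating the outer edges.
    Assumes a monotonically increasing time coordinate with at least
    two entries.
    """
    time = ds[time_coord].values
    # Midpoints between consecutive timestamps
    mid = time[:-1] + (time[1:] - time[:-1]) / 2
    # Extrapolate the first lower bound and last upper bound symmetrically
    first = time[0] - (mid[0] - time[0])
    last = time[-1] + (time[-1] - mid[-1])
    lower = np.concatenate([[first], mid])
    upper = np.concatenate([mid, [last]])
    return xr.DataArray(
        np.stack([lower, upper], axis=-1),
        dims=(time_coord, bnds_name),
        coords={time_coord: ds[time_coord]},
        name="time_bnds",
    )
```

In the merged code, a failure in this calculation is converted by the `except` branch into a `ValueError`, so `time_bnds` is either read from the raw data or reconstructed, and processing stops otherwise.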
169 changes: 169 additions & 0 deletions tests/mocks/mock_data.py
@@ -524,6 +524,175 @@ def create_mock_3d_ocean_dataset(
return ds


def create_mock_om2_dataset(nt=12, ny=300, nx=360):
"""
Create a mock ACCESS-OM2 ocean dataset with B-grid coordinates.
Uses xt_ocean/yt_ocean for T-grid points.
"""
import cftime

xt_ocean = np.linspace(0.5, 359.5, nx)
yt_ocean = np.linspace(-89.5, 89.5, ny)

time = [
cftime.DatetimeProlepticGregorian(1850, month + 1, 15) for month in range(nt)
]

data = np.random.rand(nt, ny, nx).astype(np.float32)

# Time bounds
days_per_month = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
base_days = (1850 - 1) * 365
time_bnds = np.zeros((nt, 2))
cumulative = base_days
for i in range(nt):
time_bnds[i, 0] = cumulative
time_bnds[i, 1] = cumulative + days_per_month[i % 12]
cumulative += days_per_month[i % 12]

ds = xr.Dataset(
data_vars={
"surface_temp": (
["time", "yt_ocean", "xt_ocean"],
data,
{
"long_name": "Conservative temperature",
"units": "K",
"_FillValue": np.float32(-1e20),
"standard_name": "sea_surface_temperature",
},
),
"time_bnds": (["time", "nv"], time_bnds),
},
coords={
"xt_ocean": (
"xt_ocean",
xt_ocean,
{"long_name": "tcell longitude", "units": "degrees_E"},
),
"yt_ocean": (
"yt_ocean",
yt_ocean,
{"long_name": "tcell latitude", "units": "degrees_N"},
),
"time": (
"time",
time,
{
"units": "days since 0001-01-01 00:00:00",
"calendar": "proleptic_gregorian",
"bounds": "time_bnds",
},
),
"nv": ("nv", [1.0, 2.0]),
},
attrs={
"title": "ACCESS-OM2",
"grid_type": "mosaic",
},
)
return ds


def create_mock_om3_dataset(nt=12, ny=300, nx=360):
"""
Create a mock ACCESS-OM3 ocean dataset with C-grid coordinates.
Uses xh/yh for T-grid (tracer) points.
"""
import cftime

xh = np.linspace(0.5, 359.5, nx)
yh = np.linspace(-89.5, 89.5, ny)

time = [
cftime.DatetimeProlepticGregorian(1850, month + 1, 15) for month in range(nt)
]

data = np.random.rand(nt, ny, nx).astype(np.float32)

ds = xr.Dataset(
data_vars={
"tos": (
["time", "yh", "xh"],
data,
{
"long_name": "Sea Surface Temperature",
"units": "degC",
"_FillValue": np.float32(-1e20),
},
),
},
coords={
"xh": (
"xh",
xh,
{"long_name": "h point nominal longitude", "units": "degrees_E"},
),
"yh": (
"yh",
yh,
{"long_name": "h point nominal latitude", "units": "degrees_N"},
),
"time": (
"time",
time,
{
"units": "days since 0001-01-01 00:00:00",
"calendar": "proleptic_gregorian",
},
),
},
attrs={"title": "ACCESS-OM3"},
)
return ds


def create_mock_supergrid_dataset(ny=7, nx=9):
"""
Create a minimal mock supergrid dataset for testing.

The supergrid has dimensions (2*ny+1, 2*nx+1) to represent
both cell centers and corners on a staggered grid.

Parameters
----------
ny : int
Number of tracer cells in y direction
nx : int
Number of tracer cells in x direction

Returns
-------
xr.Dataset
Mock supergrid with x and y coordinates
"""
# Supergrid dimensions
sg_ny = 2 * ny + 1
sg_nx = 2 * nx + 1

# Create simple regular lat/lon grid for testing
# x ranges from 0 to 360, y from -90 to 90
x_1d = np.linspace(0, 360, sg_nx)
y_1d = np.linspace(-90, 90, sg_ny)

x, y = np.meshgrid(x_1d, y_1d)

ds = xr.Dataset(
{
"x": (["nyp", "nxp"], x),
"y": (["nyp", "nxp"], y),
},
coords={
"nyp": np.arange(sg_ny),
"nxp": np.arange(sg_nx),
},
attrs={
"title": "Mock Supergrid for Testing",
},
)
return ds


def create_chunked_dataset(chunks=None, **kwargs):
"""Create a chunked dataset for testing dask operations."""
if chunks is None:
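A hedged sketch of how the new mocks could exercise the fallback: `create_mock_om2_dataset` ships explicit `time_bnds`, while `create_mock_om3_dataset` deliberately omits them. The test names, the `tests.mocks.mock_data` import path, and `sketch_calculate_time_bounds` (from the sketch above, standing in for `access_moppy.utilities.calculate_time_bounds`) are illustrative assumptions, not part of this PR.

```python
from tests.mocks.mock_data import (  # assumed import path for the mocks above
    create_mock_om2_dataset,
    create_mock_om3_dataset,
)


def test_om3_mock_omits_time_bounds():
    # The OM3 mock has no time_bnds, so it can drive the warning/fallback
    # path added to select_and_process_variables().
    ds = create_mock_om3_dataset(nt=3)
    assert "time_bnds" not in ds

    # Hypothetical helper from the sketch above, not the real
    # access_moppy.utilities.calculate_time_bounds.
    bnds = sketch_calculate_time_bounds(ds, time_coord="time", bnds_name="nv")
    assert bnds.shape == (3, 2)
    # Each timestamp should sit inside its own bounds interval.
    assert all(
        lo <= t <= hi for (lo, hi), t in zip(bnds.values, ds["time"].values)
    )


def test_om2_mock_ships_time_bounds():
    # The OM2 mock carries explicit bounds, matching the pre-existing path.
    ds = create_mock_om2_dataset(nt=3)
    assert "time_bnds" in ds
    assert ds["time_bnds"].dims == ("time", "nv")
```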