diff --git a/src/access_moppy/ocean.py b/src/access_moppy/ocean.py index cb1a346..efd221c 100644 --- a/src/access_moppy/ocean.py +++ b/src/access_moppy/ocean.py @@ -1,3 +1,4 @@ +import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -7,6 +8,7 @@ from access_moppy.base import CMIP6_CMORiser from access_moppy.derivations import custom_functions, evaluate_expression from access_moppy.ocean_supergrid import Supergrid +from access_moppy.utilities import calculate_time_bounds from access_moppy.vocabulary_processors import CMIP6Vocabulary @@ -60,10 +62,10 @@ def _get_dim_rename(self): def select_and_process_variables(self): """Select and process variables for the CMOR output.""" input_vars = self.mapping[self.cmor_name]["model_variables"] - time_bnds = ["time_bnds"] + bnds_required = ["time_bnds"] calc = self.mapping[self.cmor_name]["calculation"] - required_vars = set(input_vars + time_bnds) + required_vars = set(input_vars + bnds_required) self.load_dataset(required_vars=required_vars) dim_rename = self._get_dim_rename() @@ -93,14 +95,33 @@ def select_and_process_variables(self): ) self.grid_type, self.symmetric = self.infer_grid_type() - # Drop all other data variables except the CMOR variable - self.ds = self.ds[[self.cmor_name, time_bnds[0]]] + + # Check and calculate time_bnds if missing + if bnds_required[0] not in self.ds: + # Warn user that bounds are missing and will be calculated automatically + warnings.warn( + f"'{bnds_required[0]}' not found in raw data. Automatically calculating bounds for '{bnds_required[0]}' coordinate.", + UserWarning, + stacklevel=2, + ) + try: + calculated_bnds = calculate_time_bounds( + self.ds, time_coord="time", bnds_name="nv" + ) + self.ds[bnds_required[0]] = calculated_bnds + except Exception as e: + raise ValueError( + f"time_bnds is required for CMIP6 compliance but was not found " + f"in the dataset and could not be calculated: {e}" + ) + + self.ds = self.ds[[self.cmor_name, bnds_required[0]]] # Drop unused coordinates used_coords = set() dims = list(self.ds[self.cmor_name].dims) - if time_bnds[0] in self.ds: - dims = list(dict.fromkeys(dims + list(self.ds[time_bnds[0]].dims))) + if bnds_required[0] in self.ds: + dims = list(dict.fromkeys(dims + list(self.ds[bnds_required[0]].dims))) for dim in dims: if dim in self.ds.coords: used_coords.add(dim) diff --git a/tests/mocks/mock_data.py b/tests/mocks/mock_data.py index 96f3968..ffbb276 100644 --- a/tests/mocks/mock_data.py +++ b/tests/mocks/mock_data.py @@ -524,6 +524,175 @@ def create_mock_3d_ocean_dataset( return ds +def create_mock_om2_dataset(nt=12, ny=300, nx=360): + """ + Create a mock ACCESS-OM2 ocean dataset with B-grid coordinates. + Uses xt_ocean/yt_ocean for T-grid points. + """ + import cftime + + xt_ocean = np.linspace(0.5, 359.5, nx) + yt_ocean = np.linspace(-89.5, 89.5, ny) + + time = [ + cftime.DatetimeProlepticGregorian(1850, month + 1, 15) for month in range(nt) + ] + + data = np.random.rand(nt, ny, nx).astype(np.float32) + + # Time bounds + days_per_month = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + base_days = (1850 - 1) * 365 + time_bnds = np.zeros((nt, 2)) + cumulative = base_days + for i in range(nt): + time_bnds[i, 0] = cumulative + time_bnds[i, 1] = cumulative + days_per_month[i % 12] + cumulative += days_per_month[i % 12] + + ds = xr.Dataset( + data_vars={ + "surface_temp": ( + ["time", "yt_ocean", "xt_ocean"], + data, + { + "long_name": "Conservative temperature", + "units": "K", + "_FillValue": np.float32(-1e20), + "standard_name": "sea_surface_temperature", + }, + ), + "time_bnds": (["time", "nv"], time_bnds), + }, + coords={ + "xt_ocean": ( + "xt_ocean", + xt_ocean, + {"long_name": "tcell longitude", "units": "degrees_E"}, + ), + "yt_ocean": ( + "yt_ocean", + yt_ocean, + {"long_name": "tcell latitude", "units": "degrees_N"}, + ), + "time": ( + "time", + time, + { + "units": "days since 0001-01-01 00:00:00", + "calendar": "proleptic_gregorian", + "bounds": "time_bnds", + }, + ), + "nv": ("nv", [1.0, 2.0]), + }, + attrs={ + "title": "ACCESS-OM2", + "grid_type": "mosaic", + }, + ) + return ds + + +def create_mock_om3_dataset(nt=12, ny=300, nx=360): + """ + Create a mock ACCESS-OM3 ocean dataset with C-grid coordinates. + Uses xh/yh for T-grid (tracer) points. + """ + import cftime + + xh = np.linspace(0.5, 359.5, nx) + yh = np.linspace(-89.5, 89.5, ny) + + time = [ + cftime.DatetimeProlepticGregorian(1850, month + 1, 15) for month in range(nt) + ] + + data = np.random.rand(nt, ny, nx).astype(np.float32) + + ds = xr.Dataset( + data_vars={ + "tos": ( + ["time", "yh", "xh"], + data, + { + "long_name": "Sea Surface Temperature", + "units": "degC", + "_FillValue": np.float32(-1e20), + }, + ), + }, + coords={ + "xh": ( + "xh", + xh, + {"long_name": "h point nominal longitude", "units": "degrees_E"}, + ), + "yh": ( + "yh", + yh, + {"long_name": "h point nominal latitude", "units": "degrees_N"}, + ), + "time": ( + "time", + time, + { + "units": "days since 0001-01-01 00:00:00", + "calendar": "proleptic_gregorian", + }, + ), + }, + attrs={"title": "ACCESS-OM3"}, + ) + return ds + + +def create_mock_supergrid_dataset(ny=7, nx=9): + """ + Create a minimal mock supergrid dataset for testing. + + The supergrid has dimensions (2*ny+1, 2*nx+1) to represent + both cell centers and corners on a staggered grid. + + Parameters + ---------- + ny : int + Number of tracer cells in y direction + nx : int + Number of tracer cells in x direction + + Returns + ------- + xr.Dataset + Mock supergrid with x and y coordinates + """ + # Supergrid dimensions + sg_ny = 2 * ny + 1 + sg_nx = 2 * nx + 1 + + # Create simple regular lat/lon grid for testing + # x ranges from 0 to 360, y from -90 to 90 + x_1d = np.linspace(0, 360, sg_nx) + y_1d = np.linspace(-90, 90, sg_ny) + + x, y = np.meshgrid(x_1d, y_1d) + + ds = xr.Dataset( + { + "x": (["nyp", "nxp"], x), + "y": (["nyp", "nxp"], y), + }, + coords={ + "nyp": np.arange(sg_ny), + "nxp": np.arange(sg_nx), + }, + attrs={ + "title": "Mock Supergrid for Testing", + }, + ) + return ds + + def create_chunked_dataset(chunks=None, **kwargs): """Create a chunked dataset for testing dask operations.""" if chunks is None: diff --git a/tests/unit/test_ocean.py b/tests/unit/test_ocean.py new file mode 100644 index 0000000..e4d24a1 --- /dev/null +++ b/tests/unit/test_ocean.py @@ -0,0 +1,857 @@ +from unittest.mock import Mock, patch + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from access_moppy.base import CMIP6_CMORiser +from access_moppy.ocean import ( + CMIP6_Ocean_CMORiser_OM2, + CMIP6_Ocean_CMORiser_OM3, +) +from tests.mocks.mock_data import ( + create_mock_om2_dataset, + create_mock_om3_dataset, +) + + +class TestCMIP6OceanCMORiserOM2: + """Unit tests for CMIP6_Ocean_CMORiser_OM2 (B-grid).""" + + @pytest.fixture + def mock_vocab(self): + """Mock CMIP6 vocabulary for OM2.""" + vocab = Mock() + vocab.source_id = "ACCESS-OM2" + vocab.variable = {"units": "K", "type": "real"} + vocab._get_nominal_resolution = Mock(return_value="1deg") + vocab.get_required_global_attributes = Mock( + return_value={ + "variable_id": "tos", + "table_id": "Omon", + "source_id": "ACCESS-OM2", + "experiment_id": "historical", + "variant_label": "r1i1p1f1", + "grid_label": "gn", + } + ) + return vocab + + @pytest.fixture + def mock_mapping(self): + """Mock variable mapping for ocean.""" + return { + "tos": { + "model_variables": ["surface_temp"], + "calculation": {"type": "direct"}, + } + } + + @pytest.fixture + def mock_om2_dataset(self): + """Create mock OM2 dataset.""" + return create_mock_om2_dataset(nt=12, ny=30, nx=36) + + @pytest.mark.unit + def test_infer_grid_type_t_grid( + self, mock_vocab, mock_mapping, mock_om2_dataset, temp_dir + ): + """Test that T-grid is inferred from xt_ocean/yt_ocean coordinates.""" + with patch("access_moppy.ocean.Supergrid"): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = mock_om2_dataset + + grid_type, symmetric = cmoriser.infer_grid_type() + + assert grid_type == "T" + assert symmetric is None # MOM5 doesn't use symmetric memory + + @pytest.mark.unit + def test_infer_grid_type_u_grid(self, mock_vocab, mock_mapping, temp_dir): + """Test that U-grid is inferred from xu_ocean/yt_ocean coordinates.""" + ds = xr.Dataset( + coords={ + "xu_ocean": ("xu_ocean", np.arange(10)), + "yt_ocean": ("yt_ocean", np.arange(10)), + } + ) + + with patch("access_moppy.ocean.Supergrid"): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.uo", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds + + grid_type, _ = cmoriser.infer_grid_type() + + assert grid_type == "U" + + @pytest.mark.unit + def test_get_dim_rename_om2(self, mock_vocab, mock_mapping, temp_dir): + """Test dimension renaming for ACCESS-OM2.""" + with patch("access_moppy.ocean.Supergrid"): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + + dim_rename = cmoriser._get_dim_rename() + + assert dim_rename["xt_ocean"] == "i" + assert dim_rename["yt_ocean"] == "j" + assert dim_rename["xu_ocean"] == "i" + assert dim_rename["yu_ocean"] == "j" + assert dim_rename["st_ocean"] == "lev" + + @pytest.mark.unit + def test_arakawa_grid_type(self, mock_vocab, mock_mapping, temp_dir): + """Test that ACCESS-OM2 uses B-grid (Arakawa B).""" + with patch("access_moppy.ocean.Supergrid"): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + + assert cmoriser.arakawa == "B" + + @pytest.mark.unit + def test_time_bnds_loaded_and_preserved( + self, mock_vocab, mock_mapping, mock_om2_dataset, temp_dir + ): + """Test that time_bnds is loaded with other variables and preserved in output.""" + with patch("access_moppy.ocean.Supergrid"): + # Mock load_dataset to avoid file I/O + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = mock_om2_dataset + + # Run the processing + cmoriser.select_and_process_variables() + + # Verify time_bnds is in the output dataset + assert "time_bnds" in cmoriser.ds.data_vars + + # Verify only cmor_name and time_bnds are kept as data variables + assert set(cmoriser.ds.data_vars) == {"tos", "time_bnds"} + + @pytest.mark.unit + def test_time_bnds_dimensions_in_used_coords( + self, mock_vocab, mock_mapping, mock_om2_dataset, temp_dir + ): + """Test that time_bnds dimensions are identified as used coordinates.""" + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = mock_om2_dataset + + # Run the processing + cmoriser.select_and_process_variables() + + # Verify time_bnds dimensions are preserved + assert "time" in cmoriser.ds.coords + assert "nv" in cmoriser.ds.coords # nv is dimension for time_bnds + + # Verify time_bnds has correct dimensions + assert cmoriser.ds["time_bnds"].dims == ("time", "nv") + + @pytest.mark.unit + def test_auto_calculate_time_bnds_when_missing( + self, mock_vocab, mock_mapping, temp_dir + ): + """Test that time_bnds is automatically calculated when missing from source data.""" + # Create dataset WITHOUT time_bnds + ds_no_time_bnds = xr.Dataset( + data_vars={ + "surface_temp": ( + ["time", "yt_ocean", "xt_ocean"], + np.random.rand(12, 30, 36).astype(np.float32), + { + "long_name": "Sea surface temperature", + "units": "K", + }, + ), + }, + coords={ + "time": pd.date_range("2000-01-15", periods=12, freq="MS"), + "yt_ocean": ("yt_ocean", np.arange(30), {"units": "degrees_N"}), + "xt_ocean": ("xt_ocean", np.arange(36), {"units": "degrees_E"}), + }, + attrs={"title": "ACCESS-OM2", "grid_type": "mosaic"}, + ) + + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds_no_time_bnds + + # Run processing - should automatically calculate time_bnds + cmoriser.select_and_process_variables() + + # Verify time_bnds was created + assert "time_bnds" in cmoriser.ds.data_vars + assert cmoriser.ds["time_bnds"].shape == (12, 2) + + # Verify dimensions + assert cmoriser.ds["time_bnds"].dims == ("time", "nv") + + # Verify nv coordinate exists + assert "nv" in cmoriser.ds.coords + assert len(cmoriser.ds["nv"]) == 2 + + @pytest.mark.unit + def test_required_vars_includes_time_bnds( + self, mock_vocab, mock_mapping, mock_om2_dataset, temp_dir + ): + """Test that time_bnds is included in required_vars during loading.""" + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset") as mock_load: + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = mock_om2_dataset + + # Run processing + cmoriser.select_and_process_variables() + + # Verify load_dataset was called with time_bnds in required_vars + mock_load.assert_called_once() + call_args = mock_load.call_args + required_vars = call_args.kwargs.get("required_vars") or call_args[0][0] + assert "time_bnds" in required_vars + assert "surface_temp" in required_vars # model variable + + @pytest.mark.unit + def test_calculated_time_bnds_values_monthly_first_end( + self, mock_vocab, mock_mapping, temp_dir + ): + """Test that calculated time_bnds has correct month boundaries.""" + # Use proper month-start dates + ds_no_time_bnds = xr.Dataset( + data_vars={ + "surface_temp": ( + ["time", "yt_ocean", "xt_ocean"], + np.random.rand(12, 30, 36).astype(np.float32), + ), + }, + coords={ + # Generate time centered on mid-month (typical for monthly averages) + "time": pd.date_range("2000-01-01", periods=12, freq="MS") + + pd.Timedelta(days=14), + "yt_ocean": np.arange(30), + "xt_ocean": np.arange(36), + }, + ) + + print(ds_no_time_bnds["time"].values) + + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds_no_time_bnds + + cmoriser.select_and_process_variables() + + time_bnds = cmoriser.ds["time_bnds"] + + # Check first month (January 2000) + # Bounds should be [2000-01-01, 2000-02-01] + print(time_bnds.values) + first_lower = pd.Timestamp(time_bnds[0, 0].values) + first_upper = pd.Timestamp(time_bnds[0, 1].values) + + assert first_lower.year == 2000 + assert first_lower.month == 1 + assert first_lower.day == 1 + + assert first_upper.year == 2000 + assert first_upper.month == 2 + assert first_upper.day == 1 + + # Check last month (December 2000) + last_lower = pd.Timestamp(time_bnds[11, 0].values) + last_upper = pd.Timestamp(time_bnds[11, 1].values) + + assert last_lower.year == 2000 + assert last_lower.month == 12 + assert last_lower.day == 1 + + assert last_upper.year == 2001 + assert last_upper.month == 1 + assert last_upper.day == 1 + + @pytest.mark.unit + def test_calculated_time_bnds_values_monthly_range( + self, mock_vocab, mock_mapping, temp_dir + ): + """Test that calculated time_bnds has correct structure and reasonable values.""" + ds_no_time_bnds = xr.Dataset( + data_vars={ + "surface_temp": ( + ["time", "yt_ocean", "xt_ocean"], + np.random.rand(12, 30, 36).astype(np.float32), + ), + }, + coords={ + # Monthly time coordinate (mid-month) + "time": pd.date_range("2000-01-15", periods=12, freq="MS") + + pd.Timedelta(days=14), + "yt_ocean": np.arange(30), + "xt_ocean": np.arange(36), + }, + ) + + ds_no_time_bnds["time"].attrs["units"] = "days since 1850-01-01" + + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds_no_time_bnds + + cmoriser.select_and_process_variables() + + time_bnds = cmoriser.ds["time_bnds"] + + # Verify shape + assert time_bnds.shape == (12, 2) + assert time_bnds.dims == ("time", "nv") + + # For each time step, verify bounds make sense + for i in range(12): + lower = pd.Timestamp(time_bnds[i, 0].values) + upper = pd.Timestamp(time_bnds[i, 1].values) + + # Lower bound should be before upper bound + assert ( + lower < upper + ), f"Lower bound >= upper bound at index {i}: [{lower}, {upper}]" + + # Bounds should span about 1 month (28-31 days) + days_span = (upper - lower).days + assert ( + 28 <= days_span <= 31 + ), f"Unexpected time span {days_span} days at index {i}, expected 28-31 days" + + # Verify all bounds are in year 2000-2001 range (reasonable for test data) + all_bnds = time_bnds.values.flatten() + years = [pd.Timestamp(b).year for b in all_bnds] + assert all( + y in [2000, 2001] for y in years + ), f"Unexpected years in bounds: {set(years)}" + + # Verify bounds have proper attributes + assert "long_name" in time_bnds.attrs + assert "units" in time_bnds.attrs + + @pytest.mark.unit + def test_debug_time_bnds_calculation(self, mock_vocab, mock_mapping, temp_dir): + """Debug test to see what time_bnds are actually calculated.""" + ds_no_time_bnds = xr.Dataset( + data_vars={ + "surface_temp": ( + ["time", "yt_ocean", "xt_ocean"], + np.random.rand(12, 30, 36).astype(np.float32), + ), + }, + coords={ + "time": pd.date_range("2000-01-15", periods=12, freq="MS") + + pd.Timedelta(days=14), + "yt_ocean": np.arange(30), + "xt_ocean": np.arange(36), + }, + ) + + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds_no_time_bnds + + cmoriser.select_and_process_variables() + + # Print what we got + print("\n=== Debug: Time values ===") + for i, t in enumerate(cmoriser.ds["time"].values[:3]): + print(f"time[{i}]: {pd.Timestamp(t)}") + + print("\n=== Debug: Time bounds ===") + for i in range(3): + lower = pd.Timestamp(cmoriser.ds["time_bnds"][i, 0].values) + upper = pd.Timestamp(cmoriser.ds["time_bnds"][i, 1].values) + print(f"time_bnds[{i}]: [{lower}, {upper}]") + + @pytest.mark.unit + def test_existing_time_bnds_not_overwritten( + self, mock_vocab, mock_mapping, temp_dir + ): + """Test that existing time_bnds is NOT overwritten.""" + # Create dataset with existing time_bnds (with special marker values) + time = pd.date_range("2000-01-15", periods=12, freq="MS") + + # Special time_bnds with marker values to verify it's not overwritten + existing_time_bnds = np.zeros((12, 2), dtype="datetime64[ns]") + marker_time = np.datetime64("1999-12-31") # Special marker + existing_time_bnds[:, 0] = marker_time + existing_time_bnds[:, 1] = marker_time + np.timedelta64(1, "D") + + ds_with_bnds = xr.Dataset( + data_vars={ + "surface_temp": ( + ["time", "yt_ocean", "xt_ocean"], + np.random.rand(12, 30, 36).astype(np.float32), + ), + "time_bnds": ( + ["time", "nv"], + existing_time_bnds, + {"long_name": "time bounds"}, + ), + }, + coords={ + "time": time, + "yt_ocean": np.arange(30), + "xt_ocean": np.arange(36), + "nv": [0, 1], + }, + attrs={"title": "ACCESS-OM2"}, + ) + + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds_with_bnds + + cmoriser.select_and_process_variables() + + # Verify original time_bnds was kept (marker value still there) + assert cmoriser.ds["time_bnds"][0, 0].values == marker_time + assert "time_bnds" in cmoriser.ds.data_vars + + @pytest.mark.unit + def test_time_bnds_attributes(self, mock_vocab, mock_mapping, temp_dir): + """Test that calculated time_bnds has proper attributes.""" + ds_no_time_bnds = xr.Dataset( + data_vars={ + "surface_temp": ( + ["time", "yt_ocean", "xt_ocean"], + np.random.rand(12, 30, 36).astype(np.float32), + ), + }, + coords={ + "time": ( + "time", + pd.date_range("2000-01-15", periods=12, freq="MS"), + { + "long_name": "time", + "units": "days since 0001-01-01 00:00:00", + "calendar": "PROLEPTIC_GREGORIAN", + }, + ), + "yt_ocean": np.arange(30), + "xt_ocean": np.arange(36), + }, + ) + + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds_no_time_bnds + + cmoriser.select_and_process_variables() + + time_bnds = cmoriser.ds["time_bnds"] + + # Check attributes + assert "long_name" in time_bnds.attrs + assert time_bnds.attrs["long_name"] == "time bounds" + assert "units" in time_bnds.attrs + + @pytest.mark.unit + def test_only_tos_and_time_bnds_kept(self, mock_vocab, mock_mapping, temp_dir): + """Test that only CMOR variable and time_bnds are kept in final dataset.""" + # Create dataset with extra variables that should be dropped + ds_with_extras = xr.Dataset( + data_vars={ + "surface_temp": ( + ["time", "yt_ocean", "xt_ocean"], + np.random.rand(12, 30, 36).astype(np.float32), + ), + "extra_var1": ( + ["time", "yt_ocean", "xt_ocean"], + np.random.rand(12, 30, 36), + ), + "extra_var2": (["yt_ocean", "xt_ocean"], np.random.rand(30, 36)), + }, + coords={ + "time": pd.date_range("2000-01-15", periods=12, freq="MS"), + "yt_ocean": np.arange(30), + "xt_ocean": np.arange(36), + }, + attrs={"title": "ACCESS-OM2"}, + ) + + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds_with_extras + + cmoriser.select_and_process_variables() + + # Only tos and time_bnds should remain + assert set(cmoriser.ds.data_vars) == {"tos", "time_bnds"} + + # Extra variables should be dropped + assert "extra_var1" not in cmoriser.ds + assert "extra_var2" not in cmoriser.ds + assert "surface_temp" not in cmoriser.ds # Original var was renamed + + @pytest.mark.unit + def test_nv_coordinate_preserved(self, mock_vocab, mock_mapping, temp_dir): + """Test that nv coordinate is preserved (needed by time_bnds).""" + ds_no_time_bnds = xr.Dataset( + data_vars={ + "surface_temp": ( + ["time", "yt_ocean", "xt_ocean"], + np.random.rand(12, 30, 36).astype(np.float32), + ), + }, + coords={ + "time": pd.date_range("2000-01-15", periods=12, freq="MS"), + "yt_ocean": np.arange(30), + "xt_ocean": np.arange(36), + }, + attrs={"title": "ACCESS-OM2"}, + ) + + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds_no_time_bnds + + cmoriser.select_and_process_variables() + + # nv should be in coordinates + assert "nv" in cmoriser.ds.coords + + # time should be in coordinates + assert "time" in cmoriser.ds.coords + + # Spatial coordinates should be preserved (renamed) + assert "j" in cmoriser.ds.coords # Renamed from yt_ocean + assert "i" in cmoriser.ds.coords # Renamed from xt_ocean + + @pytest.mark.unit + def test_error_when_time_missing_and_cannot_calculate( + self, mock_vocab, mock_mapping, temp_dir + ): + """Test that error is raised when time coordinate is missing and time_bnds cannot be calculated.""" + # Create dataset without time coordinate + ds_no_time = xr.Dataset( + data_vars={ + "surface_temp": ( + ["yt_ocean", "xt_ocean"], + np.random.rand(30, 36).astype(np.float32), + ), + }, + coords={ + "yt_ocean": np.arange(30), + "xt_ocean": np.arange(36), + }, + ) + + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds_no_time + + # Should raise error because time_bnds cannot be calculated without time + with pytest.raises( + ValueError, match="time_bnds is required.*could not be calculated" + ): + cmoriser.select_and_process_variables() + + @pytest.mark.unit + def test_time_bnds_continuous_coverage(self, mock_vocab, mock_mapping, temp_dir): + """Test that calculated time_bnds provides continuous coverage (no gaps).""" + ds_no_time_bnds = xr.Dataset( + data_vars={ + "surface_temp": ( + ["time", "yt_ocean", "xt_ocean"], + np.random.rand(12, 30, 36).astype(np.float32), + ), + }, + coords={ + "time": pd.date_range("2000-01-15", periods=12, freq="MS"), + "yt_ocean": np.arange(30), + "xt_ocean": np.arange(36), + }, + ) + + with patch("access_moppy.ocean.Supergrid"): + with patch.object(CMIP6_CMORiser, "load_dataset", return_value=None): + cmoriser = CMIP6_Ocean_CMORiser_OM2( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds_no_time_bnds + + cmoriser.select_and_process_variables() + + time_bnds = cmoriser.ds["time_bnds"] + + # Upper bound of month i should equal lower bound of month i+1 + for i in range(len(time_bnds) - 1): + assert ( + time_bnds[i, 1].values == time_bnds[i + 1, 0].values + ), f"Gap in time_bnds between index {i} and {i+1}" + + +class TestCMIP6OceanCMORiserOM3: + """Unit tests for CMIP6_Ocean_CMORiser_OM3 (C-grid).""" + + @pytest.fixture + def mock_vocab(self): + """Mock CMIP6 vocabulary for OM3.""" + vocab = Mock() + vocab.source_id = "ACCESS-OM3" + vocab.variable = {"units": "degC", "type": "real"} + vocab._get_nominal_resolution = Mock(return_value="1deg") + vocab.get_required_global_attributes = Mock( + return_value={ + "variable_id": "tos", + "table_id": "Omon", + "source_id": "ACCESS-OM3", + "experiment_id": "historical", + "variant_label": "r1i1p1f1", + "grid_label": "gn", + } + ) + return vocab + + @pytest.fixture + def mock_mapping(self): + """Mock variable mapping.""" + return { + "tos": { + "model_variables": ["tos"], + "calculation": {"type": "direct"}, + } + } + + @pytest.fixture + def mock_om3_dataset(self): + """Create mock OM3 dataset.""" + return create_mock_om3_dataset(nt=12, ny=30, nx=36) + + @pytest.mark.unit + def test_infer_grid_type_t_grid( + self, mock_vocab, mock_mapping, mock_om3_dataset, temp_dir + ): + """Test that T-grid is inferred from xh/yh coordinates.""" + with patch("access_moppy.ocean.Supergrid"): + cmoriser = CMIP6_Ocean_CMORiser_OM3( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = mock_om3_dataset + + grid_type, symmetric = cmoriser.infer_grid_type() + + assert grid_type == "T" + assert symmetric is True # MOM6 uses symmetric memory + + @pytest.mark.unit + def test_infer_grid_type_u_grid(self, mock_vocab, mock_mapping, temp_dir): + """Test that U-grid is inferred from xq/yh coordinates.""" + ds = xr.Dataset( + coords={ + "xq": ("xq", np.arange(10)), + "yh": ("yh", np.arange(10)), + } + ) + + with patch("access_moppy.ocean.Supergrid"): + cmoriser = CMIP6_Ocean_CMORiser_OM3( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.uo", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds + + grid_type, _ = cmoriser.infer_grid_type() + + assert grid_type == "U" + + @pytest.mark.unit + def test_infer_grid_type_v_grid(self, mock_vocab, mock_mapping, temp_dir): + """Test that V-grid is inferred from xh/yq coordinates.""" + ds = xr.Dataset( + coords={ + "xh": ("xh", np.arange(10)), + "yq": ("yq", np.arange(10)), + } + ) + + with patch("access_moppy.ocean.Supergrid"): + cmoriser = CMIP6_Ocean_CMORiser_OM3( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.vo", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds + + grid_type, _ = cmoriser.infer_grid_type() + + assert grid_type == "V" + + @pytest.mark.unit + def test_infer_grid_type_c_grid(self, mock_vocab, mock_mapping, temp_dir): + """Test that C-grid (corner) is inferred from xq/yq coordinates.""" + ds = xr.Dataset( + coords={ + "xq": ("xq", np.arange(10)), + "yq": ("yq", np.arange(10)), + } + ) + + with patch("access_moppy.ocean.Supergrid"): + cmoriser = CMIP6_Ocean_CMORiser_OM3( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.var", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + cmoriser.ds = ds + + grid_type, _ = cmoriser.infer_grid_type() + + assert grid_type == "C" + + @pytest.mark.unit + def test_get_dim_rename_om3(self, mock_vocab, mock_mapping, temp_dir): + """Test dimension renaming for ACCESS-OM3.""" + with patch("access_moppy.ocean.Supergrid"): + cmoriser = CMIP6_Ocean_CMORiser_OM3( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + + dim_rename = cmoriser._get_dim_rename() + + assert dim_rename["xh"] == "i" + assert dim_rename["yh"] == "j" + assert dim_rename["xq"] == "i" + assert dim_rename["yq"] == "j" + assert dim_rename["zl"] == "lev" + + @pytest.mark.unit + def test_arakawa_grid_type(self, mock_vocab, mock_mapping, temp_dir): + """Test that ACCESS-OM3 uses C-grid (Arakawa C).""" + with patch("access_moppy.ocean.Supergrid"): + cmoriser = CMIP6_Ocean_CMORiser_OM3( + input_paths=["test.nc"], + output_path=str(temp_dir), + compound_name="Omon.tos", + cmip6_vocab=mock_vocab, + variable_mapping=mock_mapping, + ) + + assert cmoriser.arakawa == "C" diff --git a/tests/unit/test_supergrid.py b/tests/unit/test_supergrid.py new file mode 100644 index 0000000..0126d3b --- /dev/null +++ b/tests/unit/test_supergrid.py @@ -0,0 +1,243 @@ +from unittest.mock import patch + +import numpy as np +import pytest + +from access_moppy.ocean_supergrid import Supergrid +from tests.mocks.mock_data import create_mock_supergrid_dataset + + +class TestSupergrid: + """Unit tests for the Supergrid class.""" + + @pytest.fixture + def mock_supergrid_file(self, tmp_path): + """Create a temporary mock supergrid NetCDF file.""" + supergrid_ds = create_mock_supergrid_dataset(ny=7, nx=9) + filepath = tmp_path / "mock_supergrid.nc" + supergrid_ds.to_netcdf(filepath) + return str(filepath) + + @pytest.fixture + def supergrid_instance(self, mock_supergrid_file): + """Create a Supergrid instance with mocked file loading.""" + with patch.object( + Supergrid, "get_supergrid_path", return_value=mock_supergrid_file + ): + sg = Supergrid("100 km") + return sg + + # ==================== Initialization Tests ==================== + + @pytest.mark.unit + def test_init_loads_supergrid_correctly(self, mock_supergrid_file): + """Test that __init__ sets resolution and loads supergrid data.""" + with patch.object( + Supergrid, "get_supergrid_path", return_value=mock_supergrid_file + ): + sg = Supergrid("100 km") + + assert sg.nominal_resolution == "100 km" + assert sg.supergrid is not None + assert "x" in sg.supergrid + assert "y" in sg.supergrid + + # ==================== get_supergrid_path Tests ==================== + + @pytest.mark.unit + def test_get_supergrid_path_on_gadi(self): + """Test that get_supergrid_path returns Gadi path when file exists.""" + gadi_path = "/g/data/xp65/public/apps/access_moppy_data/grids/mom1deg.nc" + + with patch("os.path.exists", return_value=True): + with patch.object(Supergrid, "load_supergrid"): + sg = Supergrid.__new__(Supergrid) + sg.nominal_resolution = "100 km" + path = sg.get_supergrid_path("100 km") + + assert path == gadi_path + + @pytest.mark.unit + @pytest.mark.parametrize( + "resolution,expected_file", + [ + ("100 km", "mom1deg.nc"), + ("25 km", "mom025deg.nc"), + ("10 km", "mom01deg.nc"), + ], + ) + def test_get_supergrid_path_resolution_mapping(self, resolution, expected_file): + """Test that resolutions map to correct filenames.""" + with patch("os.path.exists", return_value=True): + with patch.object(Supergrid, "load_supergrid"): + sg = Supergrid.__new__(Supergrid) + sg.nominal_resolution = resolution + path = sg.get_supergrid_path(resolution) + + assert expected_file in path + + @pytest.mark.unit + @pytest.mark.parametrize( + "resolution,error_match", + [ + ("50 km", "Unknown or unsupported nominal resolution"), + (None, "nominal_resolution must be provided"), + ], + ) + def test_get_supergrid_path_invalid_resolution(self, resolution, error_match): + """Test that invalid resolutions raise appropriate errors.""" + with patch.object(Supergrid, "load_supergrid"): + sg = Supergrid.__new__(Supergrid) + sg.nominal_resolution = resolution + + with pytest.raises(ValueError, match=error_match): + sg.get_supergrid_path(resolution) + + # ==================== load_supergrid Tests ==================== + + @pytest.mark.unit + @pytest.mark.parametrize("cell_type", ["hcell", "qcell", "ucell", "vcell"]) + def test_load_supergrid_creates_cell_arrays(self, supergrid_instance, cell_type): + """Test that load_supergrid creates all cell type arrays with correct structure.""" + sg = supergrid_instance + + # Check centres exist + assert hasattr(sg, f"{cell_type}_centres_x") + assert hasattr(sg, f"{cell_type}_centres_y") + + # Check corners exist with 4 vertices + corners_x = getattr(sg, f"{cell_type}_corners_x") + corners_y = getattr(sg, f"{cell_type}_corners_y") + assert corners_x.shape[-1] == 4 + assert corners_y.shape[-1] == 4 + + @pytest.mark.unit + def test_load_supergrid_cell_dimensions_relationship(self, supergrid_instance): + """Test that q-cell has one more point than h-cell in each direction.""" + sg = supergrid_instance + + h_shape = sg.hcell_centres_x.shape + q_shape = sg.qcell_centres_x.shape + + assert q_shape[0] == h_shape[0] + 1 + assert q_shape[1] == h_shape[1] + 1 + + # ==================== extract_grid Tests - B-grid ==================== + + @pytest.mark.unit + @pytest.mark.parametrize("grid_type", ["T", "U", "V", "C"]) + def test_extract_grid_b_grid_all_types(self, supergrid_instance, grid_type): + """Test extract_grid returns correct structure for all B-grid types.""" + grid_info = supergrid_instance.extract_grid(grid_type=grid_type, arakawa="B") + + # All grid types should return these keys + expected_keys = [ + "latitude", + "longitude", + "vertices_latitude", + "vertices_longitude", + "i", + "j", + "vertices", + ] + for key in expected_keys: + assert key in grid_info + + # Vertices should have 4 corners + assert grid_info["vertices_latitude"].shape[-1] == 4 + + # ==================== extract_grid Tests - C-grid ==================== + + @pytest.mark.unit + @pytest.mark.parametrize("grid_type", ["T", "U", "V", "C"]) + @pytest.mark.parametrize("symmetric", [True, False]) + def test_extract_grid_c_grid_all_types( + self, supergrid_instance, grid_type, symmetric + ): + """Test extract_grid returns correct structure for all C-grid types.""" + grid_info = supergrid_instance.extract_grid( + grid_type=grid_type, arakawa="C", symmetric=symmetric + ) + + assert "latitude" in grid_info + assert "longitude" in grid_info + assert grid_info["vertices_latitude"].shape[-1] == 4 + + @pytest.mark.unit + def test_extract_grid_c_grid_symmetric_vs_asymmetric_dimensions( + self, supergrid_instance + ): + """Test that asymmetric mode has fewer points than symmetric.""" + sg = supergrid_instance + + # U-cell: asymmetric has one fewer column + u_sym = sg.extract_grid(grid_type="U", arakawa="C", symmetric=True) + u_asym = sg.extract_grid(grid_type="U", arakawa="C", symmetric=False) + assert u_asym["longitude"].shape[1] == u_sym["longitude"].shape[1] - 1 + + # V-cell: asymmetric has one fewer row + v_sym = sg.extract_grid(grid_type="V", arakawa="C", symmetric=True) + v_asym = sg.extract_grid(grid_type="V", arakawa="C", symmetric=False) + assert v_asym["latitude"].shape[0] == v_sym["latitude"].shape[0] - 1 + + # ==================== extract_grid Error Handling ==================== + + @pytest.mark.unit + def test_extract_grid_c_grid_requires_symmetric(self, supergrid_instance): + """Test that C-grid requires symmetric parameter.""" + with pytest.raises(ValueError, match="Must specify symmetric"): + supergrid_instance.extract_grid(grid_type="T", arakawa="C", symmetric=None) + + @pytest.mark.unit + def test_extract_grid_unsupported_arakawa(self, supergrid_instance): + """Test that unsupported Arakawa grid raises error.""" + with pytest.raises(ValueError, match="arakawa=.* is not supported"): + supergrid_instance.extract_grid(grid_type="T", arakawa="A") + + @pytest.mark.unit + @pytest.mark.parametrize("arakawa,symmetric", [("B", None), ("C", True)]) + def test_extract_grid_unsupported_grid_type( + self, supergrid_instance, arakawa, symmetric + ): + """Test that unsupported grid type raises error.""" + with pytest.raises(ValueError, match="is not a supported grid_type"): + supergrid_instance.extract_grid( + grid_type="X", arakawa=arakawa, symmetric=symmetric + ) + + # ==================== extract_grid Output Validation ==================== + + @pytest.mark.unit + def test_extract_grid_longitude_normalized(self, supergrid_instance): + """Test that longitude is normalized to [0, 360) range.""" + grid_info = supergrid_instance.extract_grid(grid_type="T", arakawa="B") + + lon = grid_info["longitude"].values + assert np.all(lon >= 0) + assert np.all(lon < 360) + + @pytest.mark.unit + def test_extract_grid_output_structure(self, supergrid_instance): + """Test DataArray dimensions and coordinate values.""" + grid_info = supergrid_instance.extract_grid(grid_type="T", arakawa="B") + + # Check dimensions + assert grid_info["latitude"].dims == ("j", "i") + assert grid_info["longitude"].dims == ("j", "i") + assert grid_info["vertices_latitude"].dims == ("j", "i", "vertices") + + # Check vertices shape matches spatial dims + lat = grid_info["latitude"] + lat_bnds = grid_info["vertices_latitude"] + assert lat_bnds.shape[:2] == lat.shape + + # Check coordinate values are sequential integers + np.testing.assert_array_equal( + grid_info["i"].values, np.arange(len(grid_info["i"])) + ) + np.testing.assert_array_equal( + grid_info["j"].values, np.arange(len(grid_info["j"])) + ) + np.testing.assert_array_equal( + grid_info["vertices"].values, np.array([0, 1, 2, 3]) + ) diff --git a/tests/unit/test_utilities.py b/tests/unit/test_utilities.py new file mode 100644 index 0000000..1350557 --- /dev/null +++ b/tests/unit/test_utilities.py @@ -0,0 +1,416 @@ +""" +Comprehensive tests for calculate_time_bounds and helper functions. +""" + +import os +import tempfile + +import cftime +import numpy as np +import pytest +import xarray as xr + +from access_moppy.utilities import ( + _infer_frequency, + calculate_time_bounds, +) + + +class TestCalculateTimeBoundsErrors: + """Test error handling in calculate_time_bounds.""" + + def test_missing_time_coordinate(self): + """Test error when time coordinate is missing.""" + ds = xr.Dataset(coords={"x": [1, 2, 3]}) + + with pytest.raises(ValueError, match="must contain 'time' coordinate"): + calculate_time_bounds(ds) + + def test_insufficient_time_points(self): + """Test error when less than 2 time points.""" + ds = xr.Dataset(coords={"time": [np.datetime64("2000-01-01")]}) + + with pytest.raises(ValueError, match="Need at least 2 time points"): + calculate_time_bounds(ds) + + +class TestCalculateTimeBoundsMonthly: + """Test monthly frequency time bounds calculation.""" + + def test_monthly_bounds_numpy_datetime64(self): + """Test monthly bounds with numpy datetime64.""" + time = np.array( + ["2000-01-15", "2000-02-15", "2000-03-15", "2000-12-15", "2001-01-15"], + dtype="datetime64[D]", + ) + + ds = xr.Dataset(coords={"time": time}) + time_bnds = calculate_time_bounds(ds) + + # Check shape + assert time_bnds.shape == (5, 2) + assert time_bnds.dims == ("time", "nv") + + # Check attributes + assert "long_name" in time_bnds.attrs + assert time_bnds.attrs["long_name"] == "time bounds" + + # Implementation uses midpoint method - bounds bracket the time points + for i in range(len(time)): + assert time_bnds.values[i, 0] < time[i] + assert time_bnds.values[i, 1] > time[i] + + def test_monthly_bounds_cftime(self): + """Test monthly bounds with cftime for wide date ranges.""" + time = xr.cftime_range("0850-01-15", periods=13, freq="MS", calendar="noleap") + ds = xr.Dataset(coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + # Check shape + assert time_bnds.shape == (13, 2) + + # Check first bound - implementation may use different approach + assert time_bnds.values[0, 0].year == 850 + + # Check that bounds bracket the time points + for i in range(len(time)): + assert time_bnds.values[i, 0] <= time.values[i] + assert time_bnds.values[i, 1] >= time.values[i] + + def test_monthly_bounds_year_2200(self): + """Test monthly bounds with year 2200 (edge of typical range).""" + time = xr.cftime_range( + "2200-01-15", periods=12, freq="MS", calendar="proleptic_gregorian" + ) + ds = xr.Dataset(coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + assert time_bnds.shape == (12, 2) + assert time_bnds.values[0, 0].year == 2200 + assert time_bnds.values[-1, 1].year == 2201 + + def test_monthly_bounds_february(self): + """Test monthly bounds handle February correctly.""" + time = np.array(["2000-02-15", "2000-03-15"], dtype="datetime64[D]") + ds = xr.Dataset(coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + # February bounds + assert time_bnds[0, 0] == np.datetime64("2000-02-01") + assert time_bnds[0, 1] == np.datetime64("2000-03-01") + + +class TestCalculateTimeBoundsDaily: + """Test daily frequency time bounds calculation.""" + + def test_daily_bounds_numpy_datetime64(self): + """Test daily bounds with numpy datetime64.""" + time = np.array( + ["2000-01-01", "2000-01-02", "2000-01-03", "2000-01-04"], + dtype="datetime64[D]", + ) + + ds = xr.Dataset(coords={"time": time}) + time_bnds = calculate_time_bounds(ds) + + # Check shape + assert time_bnds.shape == (4, 2) + + # Check bounds for first day + assert time_bnds[0, 0] == np.datetime64("2000-01-01") + assert time_bnds[0, 1] == np.datetime64("2000-01-02") + + # Check bounds for last day + assert time_bnds[3, 0] == np.datetime64("2000-01-04") + assert time_bnds[3, 1] == np.datetime64("2000-01-05") + + def test_daily_bounds_cftime(self): + """Test daily bounds with cftime.""" + time = xr.cftime_range("0850-01-01", periods=5, freq="D", calendar="360_day") + ds = xr.Dataset(coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + # Check shape + assert time_bnds.shape == (5, 2) + + # Check first day + assert time_bnds.values[0, 0].year == 850 + assert time_bnds.values[0, 0].day == 1 + assert time_bnds.values[0, 1].day == 2 + + # Check calendar + assert time_bnds.attrs["calendar"] == "360_day" + + def test_daily_bounds_leap_year(self): + """Test daily bounds around leap day.""" + time = np.array( + ["2000-02-28", "2000-02-29", "2000-03-01"], dtype="datetime64[D]" + ) + ds = xr.Dataset(coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + assert time_bnds[1, 0] == np.datetime64("2000-02-29") + assert time_bnds[1, 1] == np.datetime64("2000-03-01") + + +class TestCalculateTimeBoundsYearly: + """Test yearly frequency time bounds calculation.""" + + def test_yearly_bounds_numpy_datetime64(self): + """Test yearly bounds with numpy datetime64.""" + time = np.array( + ["2000-07-01", "2001-07-01", "2002-07-01"], dtype="datetime64[D]" + ) + + ds = xr.Dataset(coords={"time": time}) + time_bnds = calculate_time_bounds(ds) + + # Check shape + assert time_bnds.shape == (3, 2) + + # Check first year + assert time_bnds[0, 0] == np.datetime64("2000-01-01") + assert time_bnds[0, 1] == np.datetime64("2001-01-01") + + # Check second year + assert time_bnds[1, 0] == np.datetime64("2001-01-01") + assert time_bnds[1, 1] == np.datetime64("2002-01-01") + + def test_yearly_bounds_cftime(self): + """Test yearly bounds with cftime.""" + time = xr.cftime_range("0850-07-01", periods=3, freq="YE", calendar="noleap") + ds = xr.Dataset(coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + # Check shape + assert time_bnds.shape == (3, 2) + + # Check bounds + assert time_bnds.values[0, 0].year == 850 + assert time_bnds.values[0, 0].month == 1 + assert time_bnds.values[0, 1].year == 851 + + +class TestCalculateTimeBoundsIrregular: + """Test irregular/midpoint time bounds calculation.""" + + def test_midpoint_bounds_numpy(self): + """Test midpoint bounds for irregular data with numpy datetime64.""" + # Irregular spacing: 10, 5, 15 days + time = np.array( + ["2000-01-01", "2000-01-11", "2000-01-16", "2000-01-31"], + dtype="datetime64[D]", + ) + + ds = xr.Dataset(coords={"time": time}) + time_bnds = calculate_time_bounds(ds) + + # Check shape + assert time_bnds.shape == (4, 2) + + # First point: extrapolate backward + # dt_first = 10 days, so lower bound = 2000-01-01 - 5 days + assert time_bnds[0, 0] == np.datetime64("1999-12-27") + + # Middle point bounds should be midpoints + # time[1] = 2000-01-11, midpoint to prev = 2000-01-06, midpoint to next = 2000-01-13.5 + assert time_bnds[1, 0] == np.datetime64("2000-01-06") + + # Last point: extrapolate forward + # dt_last = 15 days, upper bound = 2000-01-31 + 7.5 days + assert time_bnds[3, 1] == np.datetime64("2000-02-07T12:00:00") + + def test_midpoint_bounds_cftime(self): + """Test midpoint bounds for irregular data with cftime.""" + # Create irregular time points + time_vals = [ + cftime.DatetimeNoLeap(850, 1, 1), + cftime.DatetimeNoLeap(850, 1, 11), + cftime.DatetimeNoLeap(850, 1, 16), + cftime.DatetimeNoLeap(850, 1, 31), + ] + time = xr.DataArray(time_vals, dims=["time"], name="time") + time.attrs["calendar"] = "noleap" + + ds = xr.Dataset(coords={"time": time}) + time_bnds = calculate_time_bounds(ds) + + # Check shape + assert time_bnds.shape == (4, 2) + + # Check that bounds are cftime objects + assert hasattr(time_bnds.values[0, 0], "calendar") + + +class TestInferFrequency: + """Test the _infer_frequency helper function.""" + + def test_infer_frequency_monthly(self): + """Test frequency inference for monthly data.""" + time_values = np.array( + ["2000-01-15", "2000-02-15", "2000-03-15"], dtype="datetime64[D]" + ) + + freq = _infer_frequency(time_values) + assert freq == "monthly" + + def test_infer_frequency_monthly_cftime(self): + """Test frequency inference for monthly cftime data.""" + time_values = xr.cftime_range("2000-01", periods=12, freq="MS").values + + freq = _infer_frequency(time_values) + assert freq == "monthly" + + def test_infer_frequency_daily(self): + """Test frequency inference for daily data.""" + time_values = np.array( + ["2000-01-01", "2000-01-02", "2000-01-03"], dtype="datetime64[D]" + ) + + freq = _infer_frequency(time_values) + assert freq == "daily" + + def test_infer_frequency_yearly(self): + """Test frequency inference for yearly data.""" + time_values = np.array( + ["2000-01-01", "2001-01-01", "2002-01-01"], dtype="datetime64[D]" + ) + + freq = _infer_frequency(time_values) + assert freq == "yearly" + + def test_infer_frequency_irregular(self): + """Test frequency inference for irregular data.""" + time_values = np.array( + ["2000-01-01", "2000-01-05", "2000-01-12"], dtype="datetime64[D]" + ) + + freq = _infer_frequency(time_values) + assert freq == "irregular" + + def test_infer_frequency_single_point(self): + """Test frequency inference with single time point.""" + time_values = np.array(["2000-01-01"], dtype="datetime64[D]") + + freq = _infer_frequency(time_values) + assert freq is None + + +class TestCalculateTimeBoundsEdgeCases: + """Test edge cases and special scenarios.""" + + def test_preserves_units(self): + """Test that time bounds preserve time units attribute.""" + time = np.array(["2000-01-15", "2000-02-15"], dtype="datetime64[D]") + ds = xr.Dataset(coords={"time": time}) + ds["time"].attrs["units"] = "days since 1850-01-01" + + time_bnds = calculate_time_bounds(ds) + + assert time_bnds.attrs["units"] == "days since 1850-01-01" + + def test_default_units(self): + """Test that no units are added when time doesn't have units.""" + time = np.array(["2000-01-15", "2000-02-15"], dtype="datetime64[D]") + ds = xr.Dataset(coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + assert "long_name" in time_bnds.attrs + + def test_nv_coordinate(self): + """Test that nv coordinate is created correctly.""" + time = np.array(["2000-01-15", "2000-02-15"], dtype="datetime64[D]") + ds = xr.Dataset(coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + assert "nv" in time_bnds.coords + np.testing.assert_array_equal(time_bnds.coords["nv"].values, [0, 1]) + + def test_different_calendars(self): + """Test with different calendar types.""" + calendars = ["noleap", "360_day", "gregorian", "proleptic_gregorian"] + # Note: 'gregorian' and 'standard' are synonyms in cftime + expected_calendars = { + "noleap": "noleap", + "360_day": "360_day", + "gregorian": "standard", # cftime converts gregorian to standard + "proleptic_gregorian": "proleptic_gregorian", + } + + for calendar in calendars: + time = xr.cftime_range("2000-01", periods=3, freq="MS", calendar=calendar) + ds = xr.Dataset(coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + assert time_bnds.shape == (3, 2) + # Check calendar if present - allow for cftime aliases + if "calendar" in time_bnds.attrs: + assert time_bnds.attrs["calendar"] in [ + calendar, + expected_calendars.get(calendar, calendar), + ] + + def test_long_time_series(self): + """Test with a long time series (performance check).""" + time = xr.cftime_range( + "0000-01", periods=100, freq="MS", calendar="proleptic_gregorian" + ) + ds = xr.Dataset(coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + assert time_bnds.shape == (100, 2) + assert time_bnds.values[0, 0].year == 0 + assert time_bnds.values[-1, 1].year == 8 # ~8 years later + + +class TestCalculateTimeBoundsIntegration: + """Integration tests for time bounds.""" + + def test_time_bounds_roundtrip(self): + """Test that time bounds can be written and read from netCDF.""" + time = xr.cftime_range("2000-01", periods=12, freq="MS") + ds = xr.Dataset({"data": (["time"], np.random.rand(12))}, coords={"time": time}) + + time_bnds = calculate_time_bounds(ds) + + # Remove encoding-related attributes to avoid xarray conflicts + time_bnds_clean = time_bnds.copy() + attrs_to_keep = {} + for key, value in time_bnds.attrs.items(): + if key not in ["units", "calendar"]: # These are handled by xarray encoding + attrs_to_keep[key] = value + time_bnds_clean.attrs = attrs_to_keep + + ds["time_bnds"] = time_bnds_clean + + # Write to temporary file + with tempfile.NamedTemporaryFile(delete=False, suffix=".nc") as tmp: + tmp_path = tmp.name + + try: + ds.to_netcdf(tmp_path) + + # Read back + ds_read = xr.open_dataset(tmp_path, decode_times=True) + + assert "time_bnds" in ds_read + assert ds_read["time_bnds"].shape == (12, 2) + + ds_read.close() + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--cov=your_module", "--cov-report=html"])