Merged
57 changes: 57 additions & 0 deletions src/access_moppy/atmosphere.py
@@ -1,8 +1,15 @@
import warnings

import numpy as np
import xarray as xr

from access_moppy.base import CMIP6_CMORiser
from access_moppy.derivations import custom_functions, evaluate_expression
from access_moppy.utilities import (
calculate_latitude_bounds,
calculate_longitude_bounds,
calculate_time_bounds,
)


class CMIP6_Atmosphere_CMORiser(CMIP6_CMORiser):
@@ -38,6 +45,56 @@ def select_and_process_variables(self):
self.load_dataset(required_vars=required_vars)
self.sort_time_dimension()

# Calculate missing bounds variables
for bnds_var in bnds_required:
if bnds_var not in self.ds.data_vars and bnds_var not in self.ds.coords:
# Extract coordinate name by removing "_bnds" suffix
coord_name = bnds_var.replace("_bnds", "")

if coord_name not in self.ds.coords:
raise ValueError(
f"Cannot calculate {bnds_var}: coordinate '{coord_name}' not found in dataset"
)

# Warn user that bounds are missing and will be calculated automatically
warnings.warn(
f"'{bnds_var}' not found in raw data. Automatically calculating bounds for '{coord_name}' coordinate.",
UserWarning,
stacklevel=2,
)

# Determine which calculation function to use based on coordinate name
if coord_name in ["time", "t"]:
# Calculate time bounds - atmosphere uses "bnds"
self.ds[bnds_var] = calculate_time_bounds(
self.ds,
time_coord=coord_name,
bnds_name="bnds", # Atmosphere uses "bnds"
)
self.ds[coord_name].attrs["bounds"] = bnds_var

elif coord_name in ["lat", "latitude", "y"]:
# Calculate latitude bounds - use "bnds" for atmosphere data
self.ds[bnds_var] = calculate_latitude_bounds(
self.ds, coord_name, bnds_name="bnds"
)
self.ds[coord_name].attrs["bounds"] = bnds_var

elif coord_name in ["lon", "longitude", "x"]:
# Calculate longitude bounds - use "bnds" for atmosphere data
self.ds[bnds_var] = calculate_longitude_bounds(
self.ds, coord_name, bnds_name="bnds"
)
self.ds[coord_name].attrs["bounds"] = bnds_var

else:
# For other coordinates, we could add more handlers or skip
warnings.warn(
f"No automatic calculation available for '{bnds_var}'. This may cause CMIP6 compliance issues.",
UserWarning,
stacklevel=2,
)

# Handle the calculation type
if calc["type"] == "direct":
# If the calculation is direct, just rename the variable
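A note for reviewers: the bounds helpers used above (calculate_time_bounds, calculate_latitude_bounds, calculate_longitude_bounds) are imported from access_moppy.utilities and are not part of this diff. For orientation, a minimal sketch of the midpoint scheme such a helper typically uses for a 1-D monotonic coordinate — illustrative only, not the utilities implementation:

import numpy as np
import xarray as xr

def midpoint_bounds(ds: xr.Dataset, coord_name: str, bnds_name: str = "bnds") -> xr.DataArray:
    # Each cell's bounds sit halfway to its neighbours; the first and
    # last bounds are extrapolated outward by half a grid step.
    vals = ds[coord_name].values
    mids = (vals[:-1] + vals[1:]) / 2.0
    lower = np.concatenate(([2 * vals[0] - mids[0]], mids))
    upper = np.concatenate((mids, [2 * vals[-1] - mids[-1]]))
    return xr.DataArray(
        np.stack([lower, upper], axis=-1),
        dims=(ds[coord_name].dims[0], bnds_name),
    )

The loop above then wires the result into CF metadata by setting the coordinate's "bounds" attribute to the name of the new bounds variable.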
238 changes: 195 additions & 43 deletions src/access_moppy/base.py
@@ -3,11 +3,12 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import cftime
import dask.array as da
import netCDF4 as nc
import psutil
import xarray as xr
from cftime import num2date
from cftime import date2num, num2date
from dask.distributed import get_client

from access_moppy.utilities import (
@@ -172,7 +173,8 @@ class CMIP6_CMORiser:

def __init__(
self,
input_paths: Union[str, List[str]],
input_data: Optional[Union[str, List[str], xr.Dataset, xr.DataArray]] = None,
*,
output_path: str,
cmip6_vocab: Any,
variable_mapping: Dict[str, Any],
@@ -185,10 +187,44 @@ def __init__(
chunk_size_mb: float = 4.0,
enable_compression: bool = True,
compression_level: int = 4,
# Backward compatibility
input_paths: Optional[Union[str, List[str]]] = None,
):
# Handle backward compatibility and validation
if input_paths is not None and input_data is None:
warnings.warn(
"The 'input_paths' parameter is deprecated. Use 'input_data' instead.",
DeprecationWarning,
stacklevel=2,
)
input_data = input_paths
elif input_paths is not None and input_data is not None:
raise ValueError(
"Cannot specify both 'input_data' and 'input_paths'. Use 'input_data'."
)
elif input_paths is None and input_data is None:
raise ValueError("Must specify either 'input_data' or 'input_paths'.")

# Determine input type and handle appropriately
self.input_is_xarray = isinstance(input_data, (xr.Dataset, xr.DataArray))

if self.input_is_xarray:
# For xarray inputs, store the dataset directly
if isinstance(input_data, xr.DataArray):
self.input_dataset = input_data.to_dataset()
else:
self.input_dataset = input_data
self.input_paths = [] # Empty list for compatibility
else:
# For file paths, store as before
self.input_paths = (
input_data
if isinstance(input_data, list)
else [input_data]
if input_data
else []
)
self.input_dataset = None
self.output_path = output_path
# Extract cmor_name from compound_name
_, self.cmor_name = compound_name.split(".")
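For reviewers trying out the new signature, a minimal sketch of both call styles. The vocab and mapping placeholders, the file names, and the "Amon.tas" compound name are illustrative only and not taken from this PR:

import xarray as xr

from access_moppy.base import CMIP6_CMORiser

vocab, mapping = {}, {}  # placeholders; real setups supply project vocabulary/mapping objects

# New style: pass an in-memory Dataset (or DataArray, or file paths) via input_data
ds = xr.Dataset({"tas": ("time", [280.0, 281.5])}, coords={"time": [0, 1]})
cmoriser = CMIP6_CMORiser(
    ds,
    output_path="out/",
    cmip6_vocab=vocab,
    variable_mapping=mapping,
    compound_name="Amon.tas",
)

# Old style still works but emits a DeprecationWarning pointing at input_data
cmoriser = CMIP6_CMORiser(
    input_paths=["chunk1.nc", "chunk2.nc"],  # hypothetical file names
    output_path="out/",
    cmip6_vocab=vocab,
    variable_mapping=mapping,
    compound_name="Amon.tas",
)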
@@ -227,53 +263,102 @@ def __repr__(self):

def load_dataset(self, required_vars: Optional[List[str]] = None):
"""
Load dataset from input files or use provided xarray objects with optional frequency validation.

Args:
required_vars: Optional list of required variables to extract
"""

# If input is already an xarray object, use it directly
if self.input_is_xarray:
self.ds = (
self.input_dataset.copy()
) # Make a copy to avoid modifying original

# SAFEGUARD: Convert cftime coordinates to numeric if present
self.ds = self._ensure_numeric_time_coordinates(self.ds)
[Collaborator] Added a safeguard to handle cases where the input data uses cftime in time-related coordinates and variables.


# Apply variable filtering if required_vars is specified
if required_vars:
available_vars = set(self.ds.data_vars) | set(self.ds.coords)
vars_to_keep = set(required_vars) & available_vars
if vars_to_keep != set(required_vars):
missing_vars = set(required_vars) - available_vars
warnings.warn(
f"Some required variables not found in dataset: {missing_vars}. "
f"Available variables: {available_vars}"
)

[Collaborator] Dropped redundant coordinates and dimensions to prevent them from affecting other parts of the workflow. These steps were previously handled implicitly by xr.open_mfdataset() and now need to be handled explicitly.

# Keep only required data variables
data_vars_to_keep = vars_to_keep & set(self.ds.data_vars)

# Collect dimensions used by these data variables
used_dims = set()
for var in data_vars_to_keep:
used_dims.update(self.ds[var].dims)

[Collaborator] time_0 is a special coordinate and needs to be handled specifically.

# Exclude auxiliary time dimension
if "time_0" in used_dims:
self.ds = self.ds.isel(time_0=0, drop=True)
used_dims.remove("time_0")

# Step 1: Keep only required data variables
self.ds = self.ds[list(data_vars_to_keep)]

# Step 2: Drop coordinates not in used_dims
coords_to_drop = [c for c in self.ds.coords if c not in used_dims]

if coords_to_drop:
self.ds = self.ds.drop_vars(coords_to_drop)
print(
f"✓ Temporal resampling will be applied: {detected_freq} → CMIP6 target frequency"
f"✓ Dropped {len(coords_to_drop)} unused coordinate(s): {coords_to_drop}"
)
else:
# Original file-based loading logic
def _preprocess(ds):
    # required_vars is a list; cast to set before intersecting
    return ds[list(set(required_vars) & set(ds.data_vars))]

# Validate frequency consistency and CMIP6 compatibility before concatenation
if self.validate_frequency and len(self.input_paths) > 0:
try:
# Enhanced validation with CMIP6 frequency compatibility
detected_freq, resampling_required = (
validate_cmip6_frequency_compatibility(
self.input_paths,
self.compound_name,
time_coord="time",
interactive=True,
)
)
if resampling_required:
print(
f"✓ Temporal resampling will be applied: {detected_freq} → CMIP6 target frequency"
)
else:
print(
f"✓ Validated compatible temporal frequency: {detected_freq}"
)
except (FrequencyMismatchError, IncompatibleFrequencyError) as e:
raise e # Re-raise these specific errors as-is
except InterruptedError as e:
raise e # Re-raise user abort
except Exception as e:
warnings.warn(
f"Could not validate temporal frequency: {e}. "
f"Proceeding with concatenation but results may be inconsistent."
)

self.ds = xr.open_mfdataset(
self.input_paths,
combine="nested", # avoids costly dimension alignment
concat_dim="time",
engine="netcdf4",
decode_cf=False,
chunks={},
preprocess=_preprocess,
parallel=True, # <--- enables concurrent preprocessing
)
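# Note on the loader settings above: combine="nested" with concat_dim="time"
# concatenates the files in the order given without coordinate alignment,
# decode_cf=False leaves time as raw numeric values, and preprocess trims each
# file to the required variables before concatenation. Standalone sketch of
# the same pattern (file names hypothetical, not from this PR):
#
#     wanted = {"tas"}
#     ds = xr.open_mfdataset(
#         ["jan.nc", "feb.nc"],
#         combine="nested",
#         concat_dim="time",
#         decode_cf=False,
#         preprocess=lambda d: d[list(wanted & set(d.data_vars))],
#         parallel=True,
#     )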

# Apply temporal resampling if enabled and needed
if self.enable_resampling and self.compound_name:
@@ -312,6 +397,73 @@ def _preprocess(ds):
self.ds = self.chunker.rechunk_dataset(self.ds)
print("✅ Dataset rechunking completed")

[Collaborator] Method to convert cftime format to numeric values.

def _ensure_numeric_time_coordinates(self, ds: xr.Dataset) -> xr.Dataset:
"""
Convert cftime objects in time-related coordinates to numeric values.

This safeguard prevents TypeError when cftime objects are implicitly
cast to numeric types in downstream operations (e.g., atmosphere.py line 174).

Args:
ds: Input dataset that may contain cftime coordinates

Returns:
Dataset with numeric time coordinates
"""
# List of common time-related coordinate names to check
time_coords = ["time", "time_bnds", "time_bounds"]

for coord_name in time_coords:
if coord_name not in ds.coords:
continue

coord = ds[coord_name]

# Check if coordinate contains cftime objects
if coord.size > 0:
# Get first value to check type
first_val = coord.values.flat[0] if coord.values.size > 0 else None

if first_val is not None and isinstance(first_val, cftime.datetime):
# Extract time encoding attributes
units = coord.attrs.get("units")
calendar = coord.attrs.get("calendar", "proleptic_gregorian")

if units is None:
warnings.warn(
f"Coordinate '{coord_name}' contains cftime objects but has no 'units' attribute. "
f"Using default: 'days since 0001-01-01'. "
f"Results may be incorrect.",
UserWarning,
)
units = "days since 0001-01-01"

# Convert cftime to numeric
try:
numeric_values = date2num(
coord.values, units=units, calendar=calendar
)

# Create new attributes dict with units and calendar
new_attrs = coord.attrs.copy()
new_attrs["units"] = units
new_attrs["calendar"] = calendar
# Replace coordinate with numeric values, preserving attributes
ds[coord_name] = (coord.dims, numeric_values, new_attrs)

print(
f"✓ Converted '{coord_name}' from cftime to numeric ({units}, {calendar})"
)

except Exception as e:
warnings.warn(
f"Failed to convert '{coord_name}' from cftime to numeric: {e}. "
f"This may cause errors in downstream processing.",
UserWarning,
)

return ds
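# Aside on the tuple assignment above: xarray accepts a (dims, values, attrs)
# tuple when assigning a variable, so the cftime values are replaced by the
# numeric ones and the updated units/calendar attributes are attached in a
# single step. Equivalent one-liner (sketch, names as in the method above):
#
#     ds[coord_name] = (coord.dims, numeric_values,
#                       {**coord.attrs, "units": units, "calendar": calendar})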

def sort_time_dimension(self):
if "time" in self.ds.dims:
self.ds = self.ds.sortby("time")
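A closing note on the cftime safeguard: the conversion in _ensure_numeric_time_coordinates hinges on cftime's date2num/num2date pair. A minimal round trip showing the transformation applied to each coordinate value (dates, units, and calendar here are illustrative):

import cftime
from cftime import date2num, num2date

units = "days since 2000-01-01"
calendar = "proleptic_gregorian"

# Three consecutive model dates expressed as cftime objects
dates = [cftime.datetime(2000, 1, d, calendar=calendar) for d in (1, 2, 3)]

nums = date2num(dates, units=units, calendar=calendar)  # -> [0, 1, 2]
back = num2date(nums, units=units, calendar=calendar)   # round-trips to the same dates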