Skip to content

Commit 0eb8d0e

Browse files
rbeucher and rhaegar325 authored
Add xarray Dataset/DataArray support to ACCESS_ESM_CMORiser (#145)
* feat: Add xarray Dataset/DataArray support to ACCESS_ESM_CMORiser - Add new `input_data` parameter to accept xarray Dataset or DataArray objects - Maintain full backward compatibility with existing `input_paths` parameter - Automatically convert DataArrays to Datasets for processing - Skip frequency validation for xarray inputs (data already loaded) - Update all CMORiser subclasses (Atmosphere, Ocean OM2/OM3) to support new interface - Preserve all existing functionality (resampling, chunking, validation) - Add comprehensive parameter validation and deprecation warnings This enables in-memory processing workflows and integration with xarray-based analysis pipelines while maintaining compatibility with existing file-based workflows. All existing tests pass (34/34) confirming no breaking changes. * fix: Improve deprecation warnings and input handling in CMORiser classes * add bounds calculator * add safeguard for time related coords and convert cftime to numeric value * fix coords issue * pre-commit fix * adjust format --------- Co-authored-by: rhaegar325 <[email protected]>
1 parent bc36128 commit 0eb8d0e

File tree

5 files changed

+856
-53
lines changed

5 files changed

+856
-53
lines changed

src/access_moppy/atmosphere.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
import warnings
2+
13
import numpy as np
24
import xarray as xr
35

46
from access_moppy.base import CMIP6_CMORiser
57
from access_moppy.derivations import custom_functions, evaluate_expression
8+
from access_moppy.utilities import (
9+
calculate_latitude_bounds,
10+
calculate_longitude_bounds,
11+
calculate_time_bounds,
12+
)
613

714

815
class CMIP6_Atmosphere_CMORiser(CMIP6_CMORiser):
@@ -38,6 +45,56 @@ def select_and_process_variables(self):
3845
self.load_dataset(required_vars=required_vars)
3946
self.sort_time_dimension()
4047

48+
# Calculate missing bounds variables
49+
for bnds_var in bnds_required:
50+
if bnds_var not in self.ds.data_vars and bnds_var not in self.ds.coords:
51+
# Extract coordinate name by removing "_bnds" suffix
52+
coord_name = bnds_var.replace("_bnds", "")
53+
54+
if coord_name not in self.ds.coords:
55+
raise ValueError(
56+
f"Cannot calculate {bnds_var}: coordinate '{coord_name}' not found in dataset"
57+
)
58+
59+
# Warn user that bounds are missing and will be calculated automatically
60+
warnings.warn(
61+
f"'{bnds_var}' not found in raw data. Automatically calculating bounds for '{coord_name}' coordinate.",
62+
UserWarning,
63+
stacklevel=2,
64+
)
65+
66+
# Determine which calculation function to use based on coordinate name
67+
if coord_name in ["time", "t"]:
68+
# Calculate time bounds - atmosphere uses "bnds"
69+
self.ds[bnds_var] = calculate_time_bounds(
70+
self.ds,
71+
time_coord=coord_name,
72+
bnds_name="bnds", # Atmosphere uses "bnds"
73+
)
74+
self.ds[coord_name].attrs["bounds"] = bnds_var
75+
76+
elif coord_name in ["lat", "latitude", "y"]:
77+
# Calculate latitude bounds - use "bnds" for atmosphere data
78+
self.ds[bnds_var] = calculate_latitude_bounds(
79+
self.ds, coord_name, bnds_name="bnds"
80+
)
81+
self.ds[coord_name].attrs["bounds"] = bnds_var
82+
83+
elif coord_name in ["lon", "longitude", "x"]:
84+
# Calculate longitude bounds - use "bnds" for atmosphere data
85+
self.ds[bnds_var] = calculate_longitude_bounds(
86+
self.ds, coord_name, bnds_name="bnds"
87+
)
88+
self.ds[coord_name].attrs["bounds"] = bnds_var
89+
90+
else:
91+
# For other coordinates, we could add more handlers or skip
92+
warnings.warn(
93+
f"No automatic calculation available for '{bnds_var}'. This may cause CMIP6 compliance issues.",
94+
UserWarning,
95+
stacklevel=2,
96+
)
97+
4198
# Handle the calculation type
4299
if calc["type"] == "direct":
43100
# If the calculation is direct, just rename the variable

src/access_moppy/base.py

Lines changed: 195 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
from pathlib import Path
44
from typing import Any, Dict, List, Optional, Union
55

6+
import cftime
67
import dask.array as da
78
import netCDF4 as nc
89
import psutil
910
import xarray as xr
10-
from cftime import num2date
11+
from cftime import date2num, num2date
1112
from dask.distributed import get_client
1213

1314
from access_moppy.utilities import (
@@ -172,7 +173,8 @@ class CMIP6_CMORiser:
172173

173174
def __init__(
174175
self,
175-
input_paths: Union[str, List[str]],
176+
input_data: Optional[Union[str, List[str], xr.Dataset, xr.DataArray]] = None,
177+
*,
176178
output_path: str,
177179
cmip6_vocab: Any,
178180
variable_mapping: Dict[str, Any],
@@ -185,10 +187,44 @@ def __init__(
185187
chunk_size_mb: float = 4.0,
186188
enable_compression: bool = True,
187189
compression_level: int = 4,
190+
# Backward compatibility
191+
input_paths: Optional[Union[str, List[str]]] = None,
188192
):
189-
self.input_paths = (
190-
input_paths if isinstance(input_paths, list) else [input_paths]
191-
)
193+
# Handle backward compatibility and validation
194+
if input_paths is not None and input_data is None:
195+
warnings.warn(
196+
"The 'input_paths' parameter is deprecated. Use 'input_data' instead.",
197+
DeprecationWarning,
198+
stacklevel=2,
199+
)
200+
input_data = input_paths
201+
elif input_paths is not None and input_data is not None:
202+
raise ValueError(
203+
"Cannot specify both 'input_data' and 'input_paths'. Use 'input_data'."
204+
)
205+
elif input_paths is None and input_data is None:
206+
raise ValueError("Must specify either 'input_data' or 'input_paths'.")
207+
208+
# Determine input type and handle appropriately
209+
self.input_is_xarray = isinstance(input_data, (xr.Dataset, xr.DataArray))
210+
211+
if self.input_is_xarray:
212+
# For xarray inputs, store the dataset directly
213+
if isinstance(input_data, xr.DataArray):
214+
self.input_dataset = input_data.to_dataset()
215+
else:
216+
self.input_dataset = input_data
217+
self.input_paths = [] # Empty list for compatibility
218+
else:
219+
# For file paths, store as before
220+
self.input_paths = (
221+
input_data
222+
if isinstance(input_data, list)
223+
else [input_data]
224+
if input_data
225+
else []
226+
)
227+
self.input_dataset = None
192228
self.output_path = output_path
193229
# Extract cmor_name from compound_name
194230
_, self.cmor_name = compound_name.split(".")
@@ -227,53 +263,102 @@ def __repr__(self):
227263

228264
def load_dataset(self, required_vars: Optional[List[str]] = None):
229265
"""
230-
Load dataset from input files with optional frequency validation.
266+
Load dataset from input files or use provided xarray objects with optional frequency validation.
231267
232268
Args:
233269
required_vars: Optional list of required variables to extract
234270
"""
235271

236-
def _preprocess(ds):
237-
return ds[list(required_vars & set(ds.data_vars))]
238-
239-
# Validate frequency consistency and CMIP6 compatibility before concatenation
240-
if self.validate_frequency and len(self.input_paths) > 0:
241-
try:
242-
# Enhanced validation with CMIP6 frequency compatibility
243-
detected_freq, resampling_required = (
244-
validate_cmip6_frequency_compatibility(
245-
self.input_paths,
246-
self.compound_name,
247-
time_coord="time",
248-
interactive=True,
272+
# If input is already an xarray object, use it directly
273+
if self.input_is_xarray:
274+
self.ds = (
275+
self.input_dataset.copy()
276+
) # Make a copy to avoid modifying original
277+
278+
# SAFEGUARD: Convert cftime coordinates to numeric if present
279+
self.ds = self._ensure_numeric_time_coordinates(self.ds)
280+
281+
# Apply variable filtering if required_vars is specified
282+
if required_vars:
283+
available_vars = set(self.ds.data_vars) | set(self.ds.coords)
284+
vars_to_keep = set(required_vars) & available_vars
285+
if vars_to_keep != set(required_vars):
286+
missing_vars = set(required_vars) - available_vars
287+
warnings.warn(
288+
f"Some required variables not found in dataset: {missing_vars}. "
289+
f"Available variables: {available_vars}"
249290
)
250-
)
251-
if resampling_required:
291+
292+
# Keep only required data variables
293+
data_vars_to_keep = vars_to_keep & set(self.ds.data_vars)
294+
295+
# Collect dimensions used by these data variables
296+
used_dims = set()
297+
for var in data_vars_to_keep:
298+
used_dims.update(self.ds[var].dims)
299+
300+
# Exclude auxiliary time dimension
301+
if "time_0" in used_dims:
302+
self.ds = self.ds.isel(time_0=0, drop=True)
303+
used_dims.remove("time_0")
304+
305+
# Step 1: Keep only required data variables
306+
self.ds = self.ds[list(data_vars_to_keep)]
307+
308+
# Step 2: Drop coordinates not in used_dims
309+
coords_to_drop = [c for c in self.ds.coords if c not in used_dims]
310+
311+
if coords_to_drop:
312+
self.ds = self.ds.drop_vars(coords_to_drop)
252313
print(
253-
f"✓ Temporal resampling will be applied: {detected_freq} → CMIP6 target frequency"
314+
f"✓ Dropped {len(coords_to_drop)} unused coordinate(s): {coords_to_drop}"
254315
)
255-
else:
256-
print(f"✓ Validated compatible temporal frequency: {detected_freq}")
257-
except (FrequencyMismatchError, IncompatibleFrequencyError) as e:
258-
raise e # Re-raise these specific errors as-is
259-
except InterruptedError as e:
260-
raise e # Re-raise user abort
261-
except Exception as e:
262-
warnings.warn(
263-
f"Could not validate temporal frequency: {e}. "
264-
f"Proceeding with concatenation but results may be inconsistent."
265-
)
266316

267-
self.ds = xr.open_mfdataset(
268-
self.input_paths,
269-
combine="nested", # avoids costly dimension alignment
270-
concat_dim="time",
271-
engine="netcdf4",
272-
decode_cf=False,
273-
chunks={},
274-
preprocess=_preprocess,
275-
parallel=True, # <--- enables concurrent preprocessing
276-
)
317+
else:
318+
# Original file-based loading logic
319+
def _preprocess(ds):
320+
return ds[list(required_vars & set(ds.data_vars))]
321+
322+
# Validate frequency consistency and CMIP6 compatibility before concatenation
323+
if self.validate_frequency and len(self.input_paths) > 0:
324+
try:
325+
# Enhanced validation with CMIP6 frequency compatibility
326+
detected_freq, resampling_required = (
327+
validate_cmip6_frequency_compatibility(
328+
self.input_paths,
329+
self.compound_name,
330+
time_coord="time",
331+
interactive=True,
332+
)
333+
)
334+
if resampling_required:
335+
print(
336+
f"✓ Temporal resampling will be applied: {detected_freq} → CMIP6 target frequency"
337+
)
338+
else:
339+
print(
340+
f"✓ Validated compatible temporal frequency: {detected_freq}"
341+
)
342+
except (FrequencyMismatchError, IncompatibleFrequencyError) as e:
343+
raise e # Re-raise these specific errors as-is
344+
except InterruptedError as e:
345+
raise e # Re-raise user abort
346+
except Exception as e:
347+
warnings.warn(
348+
f"Could not validate temporal frequency: {e}. "
349+
f"Proceeding with concatenation but results may be inconsistent."
350+
)
351+
352+
self.ds = xr.open_mfdataset(
353+
self.input_paths,
354+
combine="nested", # avoids costly dimension alignment
355+
concat_dim="time",
356+
engine="netcdf4",
357+
decode_cf=False,
358+
chunks={},
359+
preprocess=_preprocess,
360+
parallel=True, # <--- enables concurrent preprocessing
361+
)
277362

278363
# Apply temporal resampling if enabled and needed
279364
if self.enable_resampling and self.compound_name:
@@ -312,6 +397,73 @@ def _preprocess(ds):
312397
self.ds = self.chunker.rechunk_dataset(self.ds)
313398
print("✅ Dataset rechunking completed")
314399

400+
def _ensure_numeric_time_coordinates(self, ds: xr.Dataset) -> xr.Dataset:
    """
    Convert cftime objects in time-related variables to numeric values.

    This safeguard prevents TypeError when cftime objects are implicitly
    cast to numeric types in downstream operations.

    Args:
        ds: Input dataset that may contain cftime time coordinates or
            time-bounds variables.

    Returns:
        Dataset with numeric time values (the input is modified in place
        and also returned).
    """
    # Common time-related variable names to check. NOTE: bounds variables
    # ("time_bnds"/"time_bounds") are conventionally stored as *data
    # variables* in xarray, not coordinates, so both containers must be
    # inspected.
    time_names = ["time", "time_bnds", "time_bounds"]

    for name in time_names:
        # BUGFIX: the original checked only ds.coords, which silently
        # skipped bounds variables living in ds.data_vars.
        if name not in ds.coords and name not in ds.data_vars:
            continue

        var = ds[name]
        if var.size == 0:
            continue

        # Inspect the first element to decide whether conversion is needed.
        first_val = var.values.flat[0]
        if not isinstance(first_val, cftime.datetime):
            # Already numeric (or datetime64) — nothing to do.
            continue

        # Extract the time encoding. Prefer the calendar carried by the
        # cftime object itself before assuming proleptic_gregorian.
        units = var.attrs.get("units")
        calendar = var.attrs.get(
            "calendar", getattr(first_val, "calendar", "proleptic_gregorian")
        )

        if units is None:
            warnings.warn(
                f"Coordinate '{name}' contains cftime objects but has no 'units' attribute. "
                f"Using default: 'days since 0001-01-01'. "
                f"Results may be incorrect.",
                UserWarning,
            )
            units = "days since 0001-01-01"

        try:
            numeric_values = date2num(var.values, units=units, calendar=calendar)

            # Preserve existing attributes, recording the encoding used
            # for the conversion so downstream CF decoding still works.
            new_attrs = var.attrs.copy()
            new_attrs["units"] = units
            new_attrs["calendar"] = calendar
            ds[name] = (var.dims, numeric_values, new_attrs)

            print(
                f"✓ Converted '{name}' from cftime to numeric ({units}, {calendar})"
            )
        except Exception as e:
            # Best-effort conversion: warn rather than abort the whole
            # CMORisation run.
            warnings.warn(
                f"Failed to convert '{name}' from cftime to numeric: {e}. "
                f"This may cause errors in downstream processing.",
                UserWarning,
            )

    return ds
466+
315467
def sort_time_dimension(self):
316468
if "time" in self.ds.dims:
317469
self.ds = self.ds.sortby("time")

0 commit comments

Comments
 (0)