Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 103 additions & 43 deletions src/access_moppy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ class CMIP6_CMORiser:

def __init__(
self,
input_paths: Union[str, List[str]],
input_data: Optional[Union[str, List[str], xr.Dataset, xr.DataArray]] = None,
*,
output_path: str,
cmip6_vocab: Any,
variable_mapping: Dict[str, Any],
Expand All @@ -185,10 +186,44 @@ def __init__(
chunk_size_mb: float = 4.0,
enable_compression: bool = True,
compression_level: int = 4,
# Backward compatibility
input_paths: Optional[Union[str, List[str]]] = None,
):
self.input_paths = (
input_paths if isinstance(input_paths, list) else [input_paths]
)
# Handle backward compatibility and validation
if input_paths is not None and input_data is None:
warnings.warn(
"The 'input_paths' parameter is deprecated. Use 'input_data' instead.",
DeprecationWarning,
stacklevel=2,
)
input_data = input_paths
elif input_paths is not None and input_data is not None:
raise ValueError(
"Cannot specify both 'input_data' and 'input_paths'. Use 'input_data'."
)
elif input_paths is None and input_data is None:
raise ValueError("Must specify either 'input_data' or 'input_paths'.")

# Determine input type and handle appropriately
self.input_is_xarray = isinstance(input_data, (xr.Dataset, xr.DataArray))

if self.input_is_xarray:
# For xarray inputs, store the dataset directly
if isinstance(input_data, xr.DataArray):
self.input_dataset = input_data.to_dataset()
else:
self.input_dataset = input_data
self.input_paths = [] # Empty list for compatibility
else:
# For file paths, store as before
self.input_paths = (
input_data
if isinstance(input_data, list)
else [input_data]
if input_data
else []
)
self.input_dataset = None
self.output_path = output_path
# Extract cmor_name from compound_name
_, self.cmor_name = compound_name.split(".")
Expand Down Expand Up @@ -227,53 +262,78 @@ def __repr__(self):

def load_dataset(self, required_vars: Optional[List[str]] = None):
"""
Load dataset from input files with optional frequency validation.
Load dataset from input files or use provided xarray objects with optional frequency validation.

Args:
required_vars: Optional list of required variables to extract
"""

def _preprocess(ds):
return ds[list(required_vars & set(ds.data_vars))]

# Validate frequency consistency and CMIP6 compatibility before concatenation
if self.validate_frequency and len(self.input_paths) > 0:
try:
# Enhanced validation with CMIP6 frequency compatibility
detected_freq, resampling_required = (
validate_cmip6_frequency_compatibility(
self.input_paths,
self.compound_name,
time_coord="time",
interactive=True,
# If input is already an xarray object, use it directly
if self.input_is_xarray:
self.ds = (
self.input_dataset.copy()
) # Make a copy to avoid modifying original

# Apply variable filtering if required_vars is specified
if required_vars:
available_vars = set(self.ds.data_vars) | set(self.ds.coords)
vars_to_keep = set(required_vars) & available_vars
if vars_to_keep != set(required_vars):
missing_vars = set(required_vars) - available_vars
warnings.warn(
f"Some required variables not found in dataset: {missing_vars}. "
f"Available variables: {available_vars}"
)
)
if resampling_required:
print(
f"✓ Temporal resampling will be applied: {detected_freq} → CMIP6 target frequency"
# Keep only required variables plus coordinates
coords_to_keep = set(self.ds.coords)
data_vars_to_keep = vars_to_keep & set(self.ds.data_vars)
all_vars_to_keep = list(coords_to_keep | data_vars_to_keep)
self.ds = self.ds[all_vars_to_keep]
else:
# Original file-based loading logic
def _preprocess(ds):
return ds[list(required_vars & set(ds.data_vars))]

# Validate frequency consistency and CMIP6 compatibility before concatenation
if self.validate_frequency and len(self.input_paths) > 0:
try:
# Enhanced validation with CMIP6 frequency compatibility
detected_freq, resampling_required = (
validate_cmip6_frequency_compatibility(
self.input_paths,
self.compound_name,
time_coord="time",
interactive=True,
)
)
if resampling_required:
print(
f"✓ Temporal resampling will be applied: {detected_freq} → CMIP6 target frequency"
)
else:
print(
f"✓ Validated compatible temporal frequency: {detected_freq}"
)
except (FrequencyMismatchError, IncompatibleFrequencyError) as e:
raise e # Re-raise these specific errors as-is
except InterruptedError as e:
raise e # Re-raise user abort
except Exception as e:
warnings.warn(
f"Could not validate temporal frequency: {e}. "
f"Proceeding with concatenation but results may be inconsistent."
)
else:
print(f"✓ Validated compatible temporal frequency: {detected_freq}")
except (FrequencyMismatchError, IncompatibleFrequencyError) as e:
raise e # Re-raise these specific errors as-is
except InterruptedError as e:
raise e # Re-raise user abort
except Exception as e:
warnings.warn(
f"Could not validate temporal frequency: {e}. "
f"Proceeding with concatenation but results may be inconsistent."
)

self.ds = xr.open_mfdataset(
self.input_paths,
combine="nested", # avoids costly dimension alignment
concat_dim="time",
engine="netcdf4",
decode_cf=False,
chunks={},
preprocess=_preprocess,
parallel=True, # <--- enables concurrent preprocessing
)
self.ds = xr.open_mfdataset(
self.input_paths,
combine="nested", # avoids costly dimension alignment
concat_dim="time",
engine="netcdf4",
decode_cf=False,
chunks={},
preprocess=_preprocess,
parallel=True, # <--- enables concurrent preprocessing
)

# Apply temporal resampling if enabled and needed
if self.enable_resampling and self.compound_name:
Expand Down
65 changes: 59 additions & 6 deletions src/access_moppy/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pathlib import Path
from typing import Any, Dict, Optional, Union

import xarray as xr

from access_moppy.atmosphere import CMIP6_Atmosphere_CMORiser
from access_moppy.defaults import _default_parent_info
from access_moppy.ocean import CMIP6_Ocean_CMORiser_OM2, CMIP6_Ocean_CMORiser_OM3
Expand All @@ -17,7 +19,8 @@ class ACCESS_ESM_CMORiser:

def __init__(
self,
input_paths: Union[str, list],
input_data: Optional[Union[str, list, xr.Dataset, xr.DataArray]] = None,
*,
compound_name: str,
experiment_id: str,
source_id: str,
Expand All @@ -31,10 +34,12 @@ def __init__(
validate_frequency: bool = True,
enable_resampling: bool = False,
resampling_method: str = "auto",
# Backward compatibility
input_paths: Optional[Union[str, list]] = None,
):
"""
Initializes the CMORiser with necessary parameters.
:param input_paths: Path(s) to input NetCDF files.
:param input_data: Path(s) to input NetCDF files, xarray Dataset, or xarray DataArray.
:param compound_name: CMOR variable name (e.g., 'Amon.tas').
:param experiment_id: CMIP6 experiment ID (e.g., 'historical').
:param source_id: CMIP6 source ID (e.g., 'ACCESS-ESM1-5').
Expand All @@ -48,9 +53,51 @@ def __init__(
:param validate_frequency: Whether to validate temporal frequency consistency across input files (default: True).
:param enable_resampling: Whether to enable automatic temporal resampling when frequency mismatches occur (default: False).
:param resampling_method: Method for temporal resampling ('auto', 'mean', 'sum', 'min', 'max', 'first', 'last') (default: 'auto').
:param input_paths: [DEPRECATED] Use input_data instead. Kept for backward compatibility.
"""

self.input_paths = input_paths
# Handle backward compatibility and validation
if input_paths is not None and input_data is None:
warnings.warn(
"The 'input_paths' parameter is deprecated. Use 'input_data' instead.",
DeprecationWarning,
stacklevel=2,
)
input_data = input_paths
elif input_paths is not None and input_data is not None:
raise ValueError(
"Cannot specify both 'input_data' and 'input_paths'. Use 'input_data'."
)
elif input_paths is None and input_data is None:
raise ValueError("Must specify either 'input_data' or 'input_paths'.")

# Determine input type and store appropriately
self.input_is_xarray = isinstance(input_data, (xr.Dataset, xr.DataArray))

if self.input_is_xarray:
# For xarray inputs, convert DataArray to Dataset if needed
if isinstance(input_data, xr.DataArray):
self.input_dataset = input_data.to_dataset()
else:
self.input_dataset = input_data
self.input_paths = [] # Empty list for compatibility
# Disable frequency validation for xarray inputs (already loaded)
if validate_frequency:
warnings.warn(
"Disabling frequency validation for xarray input (data is already loaded).",
UserWarning,
)
validate_frequency = False
else:
# For file paths, store as before
self.input_paths = (
input_data
if isinstance(input_data, list)
else [input_data]
if input_data
else []
)
self.input_dataset = None
self.validate_frequency = validate_frequency
self.enable_resampling = enable_resampling
self.resampling_method = resampling_method
Expand Down Expand Up @@ -95,7 +142,9 @@ def __init__(
table, _ = compound_name.split(".") # cmor_name now extracted internally
if table in ("Amon", "Lmon", "Emon"):
self.cmoriser = CMIP6_Atmosphere_CMORiser(
input_paths=self.input_paths,
input_data=self.input_dataset
if self.input_is_xarray
else self.input_paths,
output_path=str(self.output_path),
cmip6_vocab=self.vocab,
variable_mapping=self.variable_mapping,
Expand All @@ -110,7 +159,9 @@ def __init__(
# ACCESS-OM3 uses MOM6 (C-grid) — requires dedicated CMORiser implementation
# that handles C-grid supergrid logic, MOM6 metadata, and OM3-specific conventions
self.cmoriser = CMIP6_Ocean_CMORiser_OM3(
input_paths=self.input_paths,
input_data=self.input_dataset
if self.input_is_xarray
else self.input_paths,
output_path=str(self.output_path),
compound_name=self.compound_name,
cmip6_vocab=self.vocab,
Expand All @@ -121,7 +172,9 @@ def __init__(
# ACCESS-OM2 uses MOM5 (B-grid) — handled by a separate CMORiser class
# specialized for B-grid variable locations and OM2-specific metadata
self.cmoriser = CMIP6_Ocean_CMORiser_OM2(
input_paths=self.input_paths,
input_data=self.input_dataset
if self.input_is_xarray
else self.input_paths,
output_path=str(self.output_path),
compound_name=self.compound_name,
cmip6_vocab=self.vocab,
Expand Down
19 changes: 16 additions & 3 deletions src/access_moppy/ocean.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List, Optional, Union

import numpy as np
import xarray as xr

from access_moppy.base import CMIP6_CMORiser
from access_moppy.derivations import custom_functions, evaluate_expression
Expand All @@ -16,7 +17,8 @@ class CMIP6_Ocean_CMORiser(CMIP6_CMORiser):

def __init__(
self,
input_paths: Union[str, List[str]],
input_data: Optional[Union[str, List[str], xr.Dataset, xr.DataArray]] = None,
*,
output_path: str,
cmip6_vocab: CMIP6Vocabulary,
variable_mapping: Dict[str, Any],
Expand All @@ -25,8 +27,11 @@ def __init__(
validate_frequency: bool = True,
enable_resampling: bool = False,
resampling_method: str = "auto",
# Backward compatibility
input_paths: Optional[Union[str, List[str]]] = None,
):
super().__init__(
input_data=input_data,
input_paths=input_paths,
output_path=output_path,
cmip6_vocab=cmip6_vocab,
Expand Down Expand Up @@ -176,14 +181,18 @@ class CMIP6_Ocean_CMORiser_OM2(CMIP6_Ocean_CMORiser):

def __init__(
self,
input_paths: Union[str, List[str]],
input_data: Optional[Union[str, List[str], xr.Dataset, xr.DataArray]] = None,
*,
output_path: str,
compound_name: str,
cmip6_vocab: CMIP6Vocabulary,
variable_mapping: Dict[str, Any],
drs_root: Optional[Path] = None,
# Backward compatibility
input_paths: Optional[Union[str, List[str]]] = None,
):
super().__init__(
input_data=input_data,
input_paths=input_paths,
output_path=output_path,
compound_name=compound_name,
Expand Down Expand Up @@ -234,14 +243,18 @@ class CMIP6_Ocean_CMORiser_OM3(CMIP6_Ocean_CMORiser):

def __init__(
self,
input_paths: Union[str, List[str]],
input_data: Optional[Union[str, List[str], xr.Dataset, xr.DataArray]] = None,
*,
output_path: str,
compound_name: str,
cmip6_vocab: CMIP6Vocabulary,
variable_mapping: Dict[str, Any],
drs_root: Optional[Path] = None,
# Backward compatibility
input_paths: Optional[Union[str, List[str]]] = None,
):
super().__init__(
input_data=input_data,
input_paths=input_paths,
output_path=output_path,
compound_name=compound_name,
Expand Down