-
Notifications
You must be signed in to change notification settings - Fork 0
Add xarray Dataset/DataArray support to ACCESS_ESM_CMORiser #145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
635f913
127a952
3cbfa36
c12a876
bb421a8
5875258
06845ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,11 +3,12 @@ | |
| from pathlib import Path | ||
| from typing import Any, Dict, List, Optional, Union | ||
|
|
||
| import cftime | ||
| import dask.array as da | ||
| import netCDF4 as nc | ||
| import psutil | ||
| import xarray as xr | ||
| from cftime import num2date | ||
| from cftime import date2num, num2date | ||
| from dask.distributed import get_client | ||
|
|
||
| from access_moppy.utilities import ( | ||
|
|
@@ -172,7 +173,8 @@ class CMIP6_CMORiser: | |
|
|
||
| def __init__( | ||
| self, | ||
| input_paths: Union[str, List[str]], | ||
| input_data: Optional[Union[str, List[str], xr.Dataset, xr.DataArray]] = None, | ||
| *, | ||
| output_path: str, | ||
| cmip6_vocab: Any, | ||
| variable_mapping: Dict[str, Any], | ||
|
|
@@ -185,10 +187,44 @@ def __init__( | |
| chunk_size_mb: float = 4.0, | ||
| enable_compression: bool = True, | ||
| compression_level: int = 4, | ||
| # Backward compatibility | ||
| input_paths: Optional[Union[str, List[str]]] = None, | ||
| ): | ||
| self.input_paths = ( | ||
| input_paths if isinstance(input_paths, list) else [input_paths] | ||
| ) | ||
| # Handle backward compatibility and validation | ||
| if input_paths is not None and input_data is None: | ||
| warnings.warn( | ||
| "The 'input_paths' parameter is deprecated. Use 'input_data' instead.", | ||
| DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| input_data = input_paths | ||
| elif input_paths is not None and input_data is not None: | ||
| raise ValueError( | ||
| "Cannot specify both 'input_data' and 'input_paths'. Use 'input_data'." | ||
| ) | ||
| elif input_paths is None and input_data is None: | ||
| raise ValueError("Must specify either 'input_data' or 'input_paths'.") | ||
|
|
||
| # Determine input type and handle appropriately | ||
| self.input_is_xarray = isinstance(input_data, (xr.Dataset, xr.DataArray)) | ||
|
|
||
| if self.input_is_xarray: | ||
| # For xarray inputs, store the dataset directly | ||
| if isinstance(input_data, xr.DataArray): | ||
| self.input_dataset = input_data.to_dataset() | ||
| else: | ||
| self.input_dataset = input_data | ||
| self.input_paths = [] # Empty list for compatibility | ||
| else: | ||
| # For file paths, store as before | ||
| self.input_paths = ( | ||
| input_data | ||
| if isinstance(input_data, list) | ||
| else [input_data] | ||
| if input_data | ||
| else [] | ||
| ) | ||
| self.input_dataset = None | ||
| self.output_path = output_path | ||
| # Extract cmor_name from compound_name | ||
| _, self.cmor_name = compound_name.split(".") | ||
|
|
@@ -227,53 +263,102 @@ def __repr__(self): | |
|
|
||
| def load_dataset(self, required_vars: Optional[List[str]] = None): | ||
| """ | ||
| Load dataset from input files with optional frequency validation. | ||
| Load dataset from input files or use provided xarray objects with optional frequency validation. | ||
|
|
||
| Args: | ||
| required_vars: Optional list of required variables to extract | ||
| """ | ||
|
|
||
| def _preprocess(ds): | ||
| return ds[list(required_vars & set(ds.data_vars))] | ||
|
|
||
| # Validate frequency consistency and CMIP6 compatibility before concatenation | ||
| if self.validate_frequency and len(self.input_paths) > 0: | ||
| try: | ||
| # Enhanced validation with CMIP6 frequency compatibility | ||
| detected_freq, resampling_required = ( | ||
| validate_cmip6_frequency_compatibility( | ||
| self.input_paths, | ||
| self.compound_name, | ||
| time_coord="time", | ||
| interactive=True, | ||
| # If input is already an xarray object, use it directly | ||
| if self.input_is_xarray: | ||
| self.ds = ( | ||
| self.input_dataset.copy() | ||
| ) # Make a copy to avoid modifying original | ||
|
|
||
| # SAFEGUARD: Convert cftime coordinates to numeric if present | ||
| self.ds = self._ensure_numeric_time_coordinates(self.ds) | ||
|
|
||
| # Apply variable filtering if required_vars is specified | ||
| if required_vars: | ||
| available_vars = set(self.ds.data_vars) | set(self.ds.coords) | ||
| vars_to_keep = set(required_vars) & available_vars | ||
| if vars_to_keep != set(required_vars): | ||
| missing_vars = set(required_vars) - available_vars | ||
| warnings.warn( | ||
| f"Some required variables not found in dataset: {missing_vars}. " | ||
| f"Available variables: {available_vars}" | ||
| ) | ||
| ) | ||
| if resampling_required: | ||
|
|
||
| # Keep only required data variables | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Dropped redundant coordinates and dimensions to prevent them from affecting other parts of the workflow. These steps were previously handled implicitly by xr.open_mfdataset() and now need to be handled explicitly. |
||
| data_vars_to_keep = vars_to_keep & set(self.ds.data_vars) | ||
|
|
||
| # Collect dimensions used by these data variables | ||
| used_dims = set() | ||
| for var in data_vars_to_keep: | ||
| used_dims.update(self.ds[var].dims) | ||
|
|
||
| # Exclude auxiliary time dimension | ||
| if "time_0" in used_dims: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| self.ds = self.ds.isel(time_0=0, drop=True) | ||
| used_dims.remove("time_0") | ||
|
|
||
| # Step 1: Keep only required data variables | ||
| self.ds = self.ds[list(data_vars_to_keep)] | ||
|
|
||
| # Step 2: Drop coordinates not in used_dims | ||
| coords_to_drop = [c for c in self.ds.coords if c not in used_dims] | ||
|
|
||
| if coords_to_drop: | ||
| self.ds = self.ds.drop_vars(coords_to_drop) | ||
| print( | ||
| f"✓ Temporal resampling will be applied: {detected_freq} → CMIP6 target frequency" | ||
| f"✓ Dropped {len(coords_to_drop)} unused coordinate(s): {coords_to_drop}" | ||
| ) | ||
| else: | ||
| print(f"✓ Validated compatible temporal frequency: {detected_freq}") | ||
| except (FrequencyMismatchError, IncompatibleFrequencyError) as e: | ||
| raise e # Re-raise these specific errors as-is | ||
| except InterruptedError as e: | ||
| raise e # Re-raise user abort | ||
| except Exception as e: | ||
| warnings.warn( | ||
| f"Could not validate temporal frequency: {e}. " | ||
| f"Proceeding with concatenation but results may be inconsistent." | ||
| ) | ||
|
|
||
| self.ds = xr.open_mfdataset( | ||
| self.input_paths, | ||
| combine="nested", # avoids costly dimension alignment | ||
| concat_dim="time", | ||
| engine="netcdf4", | ||
| decode_cf=False, | ||
| chunks={}, | ||
| preprocess=_preprocess, | ||
| parallel=True, # <--- enables concurrent preprocessing | ||
| ) | ||
| else: | ||
| # Original file-based loading logic | ||
| def _preprocess(ds): | ||
| return ds[list(required_vars & set(ds.data_vars))] | ||
|
|
||
| # Validate frequency consistency and CMIP6 compatibility before concatenation | ||
| if self.validate_frequency and len(self.input_paths) > 0: | ||
| try: | ||
| # Enhanced validation with CMIP6 frequency compatibility | ||
| detected_freq, resampling_required = ( | ||
| validate_cmip6_frequency_compatibility( | ||
| self.input_paths, | ||
| self.compound_name, | ||
| time_coord="time", | ||
| interactive=True, | ||
| ) | ||
| ) | ||
| if resampling_required: | ||
| print( | ||
| f"✓ Temporal resampling will be applied: {detected_freq} → CMIP6 target frequency" | ||
| ) | ||
| else: | ||
| print( | ||
| f"✓ Validated compatible temporal frequency: {detected_freq}" | ||
| ) | ||
| except (FrequencyMismatchError, IncompatibleFrequencyError) as e: | ||
| raise e # Re-raise these specific errors as-is | ||
| except InterruptedError as e: | ||
| raise e # Re-raise user abort | ||
| except Exception as e: | ||
| warnings.warn( | ||
| f"Could not validate temporal frequency: {e}. " | ||
| f"Proceeding with concatenation but results may be inconsistent." | ||
| ) | ||
|
|
||
| self.ds = xr.open_mfdataset( | ||
| self.input_paths, | ||
| combine="nested", # avoids costly dimension alignment | ||
| concat_dim="time", | ||
| engine="netcdf4", | ||
| decode_cf=False, | ||
| chunks={}, | ||
| preprocess=_preprocess, | ||
| parallel=True, # <--- enables concurrent preprocessing | ||
| ) | ||
|
|
||
| # Apply temporal resampling if enabled and needed | ||
| if self.enable_resampling and self.compound_name: | ||
|
|
@@ -312,6 +397,73 @@ def _preprocess(ds): | |
| self.ds = self.chunker.rechunk_dataset(self.ds) | ||
| print("✅ Dataset rechunking completed") | ||
|
|
||
def _ensure_numeric_time_coordinates(self, ds: xr.Dataset) -> xr.Dataset:
    """
    Convert cftime objects in time-related coordinates to numeric values.

    This safeguard prevents TypeError when cftime objects are implicitly
    cast to numeric types in downstream operations (e.g., atmosphere.py line 174).

    Args:
        ds: Input dataset that may contain cftime coordinates

    Returns:
        Dataset with numeric time coordinates
    """
    # List of common time-related coordinate names to check
    time_coords = ["time", "time_bnds", "time_bounds"]

    for coord_name in time_coords:
        # FIX: bounds variables ("time_bnds"/"time_bounds") conventionally live
        # in ds.data_vars, not ds.coords, so checking only ds.coords silently
        # skipped them. ds.variables covers coords AND data variables.
        if coord_name not in ds.variables:
            continue

        coord = ds[coord_name]

        # Nothing to inspect in an empty variable
        if coord.size == 0:
            continue

        # Sample the first element to detect cftime content
        first_val = coord.values.flat[0]
        if not isinstance(first_val, cftime.datetime):
            continue

        # FIX: when xarray decodes CF times it moves "units"/"calendar" from
        # .attrs to .encoding, so attrs-only lookup found nothing and the
        # (possibly wrong) fallback default was used. Prefer encoding, then attrs.
        units = coord.encoding.get("units") or coord.attrs.get("units")
        calendar = (
            coord.encoding.get("calendar")
            or coord.attrs.get("calendar", "proleptic_gregorian")
        )

        if units is None:
            warnings.warn(
                f"Coordinate '{coord_name}' contains cftime objects but has no 'units' attribute. "
                f"Using default: 'days since 0001-01-01'. "
                f"Results may be incorrect.",
                UserWarning,
            )
            units = "days since 0001-01-01"

        # Convert cftime to numeric; failures are reported but non-fatal so the
        # remaining time variables can still be processed.
        try:
            numeric_values = date2num(coord.values, units=units, calendar=calendar)

            # Record the encoding used for the conversion in the attributes so
            # downstream consumers can interpret the numeric values.
            new_attrs = coord.attrs.copy()
            new_attrs["units"] = units
            new_attrs["calendar"] = calendar
            # Replace the variable with numeric values, preserving dims/attrs
            ds[coord_name] = (coord.dims, numeric_values, new_attrs)

            print(
                f"✓ Converted '{coord_name}' from cftime to numeric ({units}, {calendar})"
            )

        except Exception as e:
            warnings.warn(
                f"Failed to convert '{coord_name}' from cftime to numeric: {e}. "
                f"This may cause errors in downstream processing.",
                UserWarning,
            )

    return ds
|
|
||
def sort_time_dimension(self):
    """Sort self.ds along its "time" dimension; no-op when absent."""
    if "time" not in self.ds.dims:
        return
    self.ds = self.ds.sortby("time")
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a safeguard to handle cases where the input data use cftime in time-related coordinates and variables.