33from pathlib import Path
44from typing import Any , Dict , List , Optional , Union
55
6+ import cftime
67import dask .array as da
78import netCDF4 as nc
89import psutil
910import xarray as xr
10- from cftime import num2date
11+ from cftime import date2num , num2date
1112from dask .distributed import get_client
1213
1314from access_moppy .utilities import (
@@ -172,7 +173,8 @@ class CMIP6_CMORiser:
172173
173174 def __init__ (
174175 self ,
175- input_paths : Union [str , List [str ]],
176+ input_data : Optional [Union [str , List [str ], xr .Dataset , xr .DataArray ]] = None ,
177+ * ,
176178 output_path : str ,
177179 cmip6_vocab : Any ,
178180 variable_mapping : Dict [str , Any ],
@@ -185,10 +187,44 @@ def __init__(
185187 chunk_size_mb : float = 4.0 ,
186188 enable_compression : bool = True ,
187189 compression_level : int = 4 ,
190+ # Backward compatibility
191+ input_paths : Optional [Union [str , List [str ]]] = None ,
188192 ):
189- self .input_paths = (
190- input_paths if isinstance (input_paths , list ) else [input_paths ]
191- )
193+ # Handle backward compatibility and validation
194+ if input_paths is not None and input_data is None :
195+ warnings .warn (
196+ "The 'input_paths' parameter is deprecated. Use 'input_data' instead." ,
197+ DeprecationWarning ,
198+ stacklevel = 2 ,
199+ )
200+ input_data = input_paths
201+ elif input_paths is not None and input_data is not None :
202+ raise ValueError (
203+ "Cannot specify both 'input_data' and 'input_paths'. Use 'input_data'."
204+ )
205+ elif input_paths is None and input_data is None :
206+ raise ValueError ("Must specify either 'input_data' or 'input_paths'." )
207+
208+ # Determine input type and handle appropriately
209+ self .input_is_xarray = isinstance (input_data , (xr .Dataset , xr .DataArray ))
210+
211+ if self .input_is_xarray :
212+ # For xarray inputs, store the dataset directly
213+ if isinstance (input_data , xr .DataArray ):
214+ self .input_dataset = input_data .to_dataset ()
215+ else :
216+ self .input_dataset = input_data
217+ self .input_paths = [] # Empty list for compatibility
218+ else :
219+ # For file paths, store as before
220+ self .input_paths = (
221+ input_data
222+ if isinstance (input_data , list )
223+ else [input_data ]
224+ if input_data
225+ else []
226+ )
227+ self .input_dataset = None
192228 self .output_path = output_path
193229 # Extract cmor_name from compound_name
194230 _ , self .cmor_name = compound_name .split ("." )
@@ -227,53 +263,102 @@ def __repr__(self):
227263
228264 def load_dataset (self , required_vars : Optional [List [str ]] = None ):
229265 """
230- Load dataset from input files with optional frequency validation.
266+ Load dataset from input files or use provided xarray objects with optional frequency validation.
231267
232268 Args:
233269 required_vars: Optional list of required variables to extract
234270 """
235271
236- def _preprocess (ds ):
237- return ds [list (required_vars & set (ds .data_vars ))]
238-
239- # Validate frequency consistency and CMIP6 compatibility before concatenation
240- if self .validate_frequency and len (self .input_paths ) > 0 :
241- try :
242- # Enhanced validation with CMIP6 frequency compatibility
243- detected_freq , resampling_required = (
244- validate_cmip6_frequency_compatibility (
245- self .input_paths ,
246- self .compound_name ,
247- time_coord = "time" ,
248- interactive = True ,
272+ # If input is already an xarray object, use it directly
273+ if self .input_is_xarray :
274+ self .ds = (
275+ self .input_dataset .copy ()
276+ ) # Make a copy to avoid modifying original
277+
278+ # SAFEGUARD: Convert cftime coordinates to numeric if present
279+ self .ds = self ._ensure_numeric_time_coordinates (self .ds )
280+
281+ # Apply variable filtering if required_vars is specified
282+ if required_vars :
283+ available_vars = set (self .ds .data_vars ) | set (self .ds .coords )
284+ vars_to_keep = set (required_vars ) & available_vars
285+ if vars_to_keep != set (required_vars ):
286+ missing_vars = set (required_vars ) - available_vars
287+ warnings .warn (
288+ f"Some required variables not found in dataset: { missing_vars } . "
289+ f"Available variables: { available_vars } "
249290 )
250- )
251- if resampling_required :
291+
292+ # Keep only required data variables
293+ data_vars_to_keep = vars_to_keep & set (self .ds .data_vars )
294+
295+ # Collect dimensions used by these data variables
296+ used_dims = set ()
297+ for var in data_vars_to_keep :
298+ used_dims .update (self .ds [var ].dims )
299+
300+ # Exclude auxiliary time dimension
301+ if "time_0" in used_dims :
302+ self .ds = self .ds .isel (time_0 = 0 , drop = True )
303+ used_dims .remove ("time_0" )
304+
305+ # Step 1: Keep only required data variables
306+ self .ds = self .ds [list (data_vars_to_keep )]
307+
308+ # Step 2: Drop coordinates not in used_dims
309+ coords_to_drop = [c for c in self .ds .coords if c not in used_dims ]
310+
311+ if coords_to_drop :
312+ self .ds = self .ds .drop_vars (coords_to_drop )
252313 print (
253- f"✓ Temporal resampling will be applied : { detected_freq } → CMIP6 target frequency "
314+ f"✓ Dropped { len ( coords_to_drop ) } unused coordinate(s) : { coords_to_drop } "
254315 )
255- else :
256- print (f"✓ Validated compatible temporal frequency: { detected_freq } " )
257- except (FrequencyMismatchError , IncompatibleFrequencyError ) as e :
258- raise e # Re-raise these specific errors as-is
259- except InterruptedError as e :
260- raise e # Re-raise user abort
261- except Exception as e :
262- warnings .warn (
263- f"Could not validate temporal frequency: { e } . "
264- f"Proceeding with concatenation but results may be inconsistent."
265- )
266316
267- self .ds = xr .open_mfdataset (
268- self .input_paths ,
269- combine = "nested" , # avoids costly dimension alignment
270- concat_dim = "time" ,
271- engine = "netcdf4" ,
272- decode_cf = False ,
273- chunks = {},
274- preprocess = _preprocess ,
275- parallel = True , # <--- enables concurrent preprocessing
276- )
317+ else :
318+ # Original file-based loading logic
319+ def _preprocess (ds ):
320+ return ds [list (required_vars & set (ds .data_vars ))]
321+
322+ # Validate frequency consistency and CMIP6 compatibility before concatenation
323+ if self .validate_frequency and len (self .input_paths ) > 0 :
324+ try :
325+ # Enhanced validation with CMIP6 frequency compatibility
326+ detected_freq , resampling_required = (
327+ validate_cmip6_frequency_compatibility (
328+ self .input_paths ,
329+ self .compound_name ,
330+ time_coord = "time" ,
331+ interactive = True ,
332+ )
333+ )
334+ if resampling_required :
335+ print (
336+ f"✓ Temporal resampling will be applied: { detected_freq } → CMIP6 target frequency"
337+ )
338+ else :
339+ print (
340+ f"✓ Validated compatible temporal frequency: { detected_freq } "
341+ )
342+ except (FrequencyMismatchError , IncompatibleFrequencyError ) as e :
343+ raise e # Re-raise these specific errors as-is
344+ except InterruptedError as e :
345+ raise e # Re-raise user abort
346+ except Exception as e :
347+ warnings .warn (
348+ f"Could not validate temporal frequency: { e } . "
349+ f"Proceeding with concatenation but results may be inconsistent."
350+ )
351+
352+ self .ds = xr .open_mfdataset (
353+ self .input_paths ,
354+ combine = "nested" , # avoids costly dimension alignment
355+ concat_dim = "time" ,
356+ engine = "netcdf4" ,
357+ decode_cf = False ,
358+ chunks = {},
359+ preprocess = _preprocess ,
360+ parallel = True , # <--- enables concurrent preprocessing
361+ )
277362
278363 # Apply temporal resampling if enabled and needed
279364 if self .enable_resampling and self .compound_name :
@@ -312,6 +397,73 @@ def _preprocess(ds):
312397 self .ds = self .chunker .rechunk_dataset (self .ds )
313398 print ("✅ Dataset rechunking completed" )
314399
400+ def _ensure_numeric_time_coordinates (self , ds : xr .Dataset ) -> xr .Dataset :
401+ """
402+ Convert cftime objects in time-related coordinates to numeric values.
403+
404+ This safeguard prevents TypeError when cftime objects are implicitly
405+ cast to numeric types in downstream operations (e.g., atmosphere.py line 174).
406+
407+ Args:
408+ ds: Input dataset that may contain cftime coordinates
409+
410+ Returns:
411+ Dataset with numeric time coordinates
412+ """
413+ # List of common time-related coordinate names to check
414+ time_coords = ["time" , "time_bnds" , "time_bounds" ]
415+
416+ for coord_name in time_coords :
417+ if coord_name not in ds .coords :
418+ continue
419+
420+ coord = ds [coord_name ]
421+
422+ # Check if coordinate contains cftime objects
423+ if coord .size > 0 :
424+ # Get first value to check type
425+ first_val = coord .values .flat [0 ] if coord .values .size > 0 else None
426+
427+ if first_val is not None and isinstance (first_val , cftime .datetime ):
428+ # Extract time encoding attributes
429+ units = coord .attrs .get ("units" )
430+ calendar = coord .attrs .get ("calendar" , "proleptic_gregorian" )
431+
432+ if units is None :
433+ warnings .warn (
434+ f"Coordinate '{ coord_name } ' contains cftime objects but has no 'units' attribute. "
435+ f"Using default: 'days since 0001-01-01'. "
436+ f"Results may be incorrect." ,
437+ UserWarning ,
438+ )
439+ units = "days since 0001-01-01"
440+
441+ # Convert cftime to numeric
442+ try :
443+ numeric_values = date2num (
444+ coord .values , units = units , calendar = calendar
445+ )
446+
447+ # Create new attributes dict with units and calendar
448+ new_attrs = coord .attrs .copy ()
449+ new_attrs ["units" ] = units
450+ new_attrs ["calendar" ] = calendar
451+ # Replace coordinate with numeric values, preserving attributes
452+ ds [coord_name ] = (coord .dims , numeric_values , new_attrs )
453+
454+ print (
455+ f"✓ Converted '{ coord_name } ' from cftime to numeric ({ units } , { calendar } )"
456+ )
457+
458+ except Exception as e :
459+ warnings .warn (
460+ f"Failed to convert '{ coord_name } ' from cftime to numeric: { e } . "
461+ f"This may cause errors in downstream processing." ,
462+ UserWarning ,
463+ )
464+
465+ return ds
466+
315467 def sort_time_dimension (self ):
316468 if "time" in self .ds .dims :
317469 self .ds = self .ds .sortby ("time" )
0 commit comments