11import collections .abc
22import enum
3+ from collections import defaultdict
34from typing import List , Optional , Tuple
45
56import numpy as np
@@ -57,7 +58,10 @@ def __init__(
5758 self .time_index_number : int = None
5859 self ._process_selection (xr_ds )
5960
60- self .yt_coord_names = _convert_to_yt_internal_coords (self .selected_coords )
61+ xr_field = xr_ds .data_vars [fields [0 ]]
62+ self .yt_coord_names = _convert_to_yt_internal_coords (
63+ self .selected_coords , xr_field
64+ )
6165
6266 def _find_units (self , xr_ds ) -> dict :
6367 units = {}
@@ -332,10 +336,26 @@ def interp_validation(self, geometry):
332336}
333337
334338
335- known_coord_aliases = {}
339+ _default_known_coord_aliases = {}
336340for ky , vals in _coord_aliases .items ():
337341 for val in vals :
338- known_coord_aliases [val ] = ky
342+ _default_known_coord_aliases [val ] = ky
343+
344+ known_coord_aliases = _default_known_coord_aliases .copy ()
345+
346+
347+ def reset_coordinate_aliases ():
348+ kys_to_pop = [
349+ ky
350+ for ky in known_coord_aliases .keys ()
351+ if ky not in _default_known_coord_aliases
352+ ]
353+ for ky in kys_to_pop :
354+ known_coord_aliases .pop (ky )
355+
356+ for ky , val in _default_known_coord_aliases .items ():
357+ known_coord_aliases [ky ] = val
358+
339359
340360_expected_yt_axes = {
341361 "cartesian" : set (["x" , "y" , "z" ]),
@@ -351,20 +371,55 @@ def interp_validation(self, geometry):
351371 _yt_coord_names += list (vals )
352372
353373
354- def _convert_to_yt_internal_coords (coord_list ):
374+ def _invert_cf_standard_names (standard_names : dict ):
375+ inverted_mapping = defaultdict (lambda : set ())
376+ for ky , vals in standard_names .items ():
377+ for val in vals :
378+ inverted_mapping [val ].add (ky )
379+ return inverted_mapping
380+
381+
382+ def _cf_xr_coord_disamb (
383+ cname : str , xr_field : xr .DataArray
384+ ) -> Tuple [Optional [str ], bool ]:
385+ # returns a tuple of (validated name, cf_xarray_is_installed)
386+ try :
387+ import cf_xarray as cfx # noqa: F401
388+ except ImportError :
389+ return None , False
390+
391+ nm_to_standard = _invert_cf_standard_names (xr_field .cf .standard_names )
392+ if cname in nm_to_standard :
393+ cf_standard_name = nm_to_standard [cname ]
394+ if len (cf_standard_name ):
395+ cf_standard_name = list (cf_standard_name )[0 ]
396+ if cf_standard_name in known_coord_aliases :
397+ return cf_standard_name , True
398+ return None , True
399+
400+
401+ def _convert_to_yt_internal_coords (coord_list : List [str ], xr_field : xr .DataArray ):
355402 yt_coords = []
356403 for c in coord_list :
357404 cname = c .lower ()
405+ cf_xarray_exists = None
358406 if cname in known_coord_aliases :
359- yt_coords . append ( known_coord_aliases [cname ])
407+ valid_coord_name = known_coord_aliases [cname ]
360408 elif cname in _yt_coord_names :
361- yt_coords . append ( cname )
409+ valid_coord_name = cname
362410 else :
363- raise ValueError (
411+ valid_coord_name , cf_xarray_exists = _cf_xr_coord_disamb (cname , xr_field )
412+ if valid_coord_name is None :
413+ msg = (
364414 f"{ c } is not a known coordinate. To load in yt, you "
365- f "must supply an alias via the yt_xarray.known_coord_aliases"
366- f " dictionary. "
415+ "must supply an alias via the yt_xarray.known_coord_aliases"
416+ " dictionary"
367417 )
418+ if cf_xarray_exists is False :
419+ msg += " or install cf_xarray to check for additional aliases."
420+ raise ValueError (msg )
421+
422+ yt_coords .append (valid_coord_name )
368423
369424 return yt_coords
370425
0 commit comments