Skip to content

Commit e0c80a4

Browse files
authored
Merge pull request #83 from chrishavlin/cfxr_coord_disamb
add coordinate disambiguation with cf_xarray
2 parents 30e4b1f + f099d79 commit e0c80a4

File tree

6 files changed

+1181
-34
lines changed

6 files changed

+1181
-34
lines changed

docs/examples/example_002_coord_aliases.ipynb

Lines changed: 1070 additions & 19 deletions
Large diffs are not rendered by default.

docs/faq.rst

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
Some common problems (and how to solve them)
2-
============================================
1+
FAQ
2+
===
33

44
yt and xarray have many similarities in how they handle their Datasets, but
55
there are also many aspects that differ to varying degree. This page describes
66
some of the difficulties you may encounter while using yt_xarray to communicate
7-
between the two.
7+
between the two and how to solve those issues.
88

99
xarray datasets with a mix of dimensionality
1010
********************************************
@@ -26,7 +26,9 @@ yt datasets have a fixed expectation for coordinate names. In cartesian, these
2626
coordinate names are ``'x'``, ``'y'``, ``'z'`` while for geographic coordinate systems
2727
the coordinate names are ``'latitude'``, ``'longtiude'`` and then either ``'altitude'``
2828
or ``'depth'``. To work with xarray variables defined with coordinate names that
29-
differ from these, yt_xarray provides some coordinate aliasing.
29+
differ from these, yt_xarray provides some coordinate aliasing, which in part relies
30+
on `cf_xarray <https://cf-xarray.readthedocs.io>`_ (if it is installed) for
31+
additional conversion to standard names.
3032

3133
See :doc:`examples/example_002_coord_aliases` for an example.
3234

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ dependencies=['yt>=4.2.0', 'h5py>=3.4.0', 'pooch>=1.5.1', 'xarray']
2828
"Bug Tracker" = "https://github.com/data-exp-lab/yt_xarray/issues"
2929

3030
[project.optional-dependencies]
31-
full = ["netCDF4", "scipy", "dask[complete]"]
31+
full = ["netCDF4", "scipy", "dask[complete]", "cf_xarray"]
3232
test = ["pytest", "pytest-cov", "cartopy"]
3333
docs = ["Sphinx==7.2.6", "jinja2==3.1.2", "nbsphinx==0.9.3"]
3434

yt_xarray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
# import the xarray accessor so it is registered with xarray
99

1010
from .accessor import YtAccessor
11-
from .accessor._xr_to_yt import known_coord_aliases
11+
from .accessor._xr_to_yt import known_coord_aliases, reset_coordinate_aliases
1212
from .yt_xarray import open_dataset

yt_xarray/accessor/_xr_to_yt.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import collections.abc
22
import enum
3+
from collections import defaultdict
34
from typing import List, Optional, Tuple
45

56
import numpy as np
@@ -57,7 +58,10 @@ def __init__(
5758
self.time_index_number: int = None
5859
self._process_selection(xr_ds)
5960

60-
self.yt_coord_names = _convert_to_yt_internal_coords(self.selected_coords)
61+
xr_field = xr_ds.data_vars[fields[0]]
62+
self.yt_coord_names = _convert_to_yt_internal_coords(
63+
self.selected_coords, xr_field
64+
)
6165

6266
def _find_units(self, xr_ds) -> dict:
6367
units = {}
@@ -332,10 +336,26 @@ def interp_validation(self, geometry):
332336
}
333337

334338

335-
known_coord_aliases = {}
339+
_default_known_coord_aliases = {}
336340
for ky, vals in _coord_aliases.items():
337341
for val in vals:
338-
known_coord_aliases[val] = ky
342+
_default_known_coord_aliases[val] = ky
343+
344+
known_coord_aliases = _default_known_coord_aliases.copy()
345+
346+
347+
def reset_coordinate_aliases():
348+
kys_to_pop = [
349+
ky
350+
for ky in known_coord_aliases.keys()
351+
if ky not in _default_known_coord_aliases
352+
]
353+
for ky in kys_to_pop:
354+
known_coord_aliases.pop(ky)
355+
356+
for ky, val in _default_known_coord_aliases.items():
357+
known_coord_aliases[ky] = val
358+
339359

340360
_expected_yt_axes = {
341361
"cartesian": set(["x", "y", "z"]),
@@ -351,20 +371,55 @@ def interp_validation(self, geometry):
351371
_yt_coord_names += list(vals)
352372

353373

354-
def _convert_to_yt_internal_coords(coord_list):
374+
def _invert_cf_standard_names(standard_names: dict):
375+
inverted_mapping = defaultdict(lambda: set())
376+
for ky, vals in standard_names.items():
377+
for val in vals:
378+
inverted_mapping[val].add(ky)
379+
return inverted_mapping
380+
381+
382+
def _cf_xr_coord_disamb(
383+
cname: str, xr_field: xr.DataArray
384+
) -> Tuple[Optional[str], bool]:
385+
# returns a tuple of (validated name, cf_xarray_is_installed)
386+
try:
387+
import cf_xarray as cfx # noqa: F401
388+
except ImportError:
389+
return None, False
390+
391+
nm_to_standard = _invert_cf_standard_names(xr_field.cf.standard_names)
392+
if cname in nm_to_standard:
393+
cf_standard_name = nm_to_standard[cname]
394+
if len(cf_standard_name):
395+
cf_standard_name = list(cf_standard_name)[0]
396+
if cf_standard_name in known_coord_aliases:
397+
return cf_standard_name, True
398+
return None, True
399+
400+
401+
def _convert_to_yt_internal_coords(coord_list: List[str], xr_field: xr.DataArray):
355402
yt_coords = []
356403
for c in coord_list:
357404
cname = c.lower()
405+
cf_xarray_exists = None
358406
if cname in known_coord_aliases:
359-
yt_coords.append(known_coord_aliases[cname])
407+
valid_coord_name = known_coord_aliases[cname]
360408
elif cname in _yt_coord_names:
361-
yt_coords.append(cname)
409+
valid_coord_name = cname
362410
else:
363-
raise ValueError(
411+
valid_coord_name, cf_xarray_exists = _cf_xr_coord_disamb(cname, xr_field)
412+
if valid_coord_name is None:
413+
msg = (
364414
f"{c} is not a known coordinate. To load in yt, you "
365-
f"must supply an alias via the yt_xarray.known_coord_aliases"
366-
f" dictionary."
415+
"must supply an alias via the yt_xarray.known_coord_aliases"
416+
" dictionary"
367417
)
418+
if cf_xarray_exists is False:
419+
msg += " or install cf_xarray to check for additional aliases."
420+
raise ValueError(msg)
421+
422+
yt_coords.append(valid_coord_name)
368423

369424
return yt_coords
370425

yt_xarray/tests/test_xr_to_yt.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import builtins
2+
13
import numpy as np
24
import pytest
35
import xarray as xr
@@ -535,6 +537,43 @@ def test_reversed_axis(stretched, use_callable, chunksizes):
535537
assert np.all(np.isfinite(vals))
536538

537539

540+
def test_cf_xarray_disambiguation():
541+
from cf_xarray.datasets import airds
542+
543+
# run the whole selection (will internally run coord disambiguation)
544+
sel = xr2yt.Selection(
545+
airds, fields=["air"], sel_dict={"time": 0}, sel_dict_type="isel"
546+
)
547+
xr_da = airds.air
548+
selected_names = []
549+
for c in sel.selected_coords:
550+
selected_names.append(xr2yt._cf_xr_coord_disamb(c, xr_da)[0])
551+
552+
assert "latitude" in selected_names
553+
assert "longitude" in selected_names
554+
555+
556+
def test_missing_cfxarray(monkeypatch):
557+
from cf_xarray.datasets import airds
558+
559+
def _bad_import(name, globals=None, locals=None, fromlist=(), level=0):
560+
raise ImportError
561+
562+
xr_da = airds.air
563+
clist = list(xr_da.dims)
564+
with monkeypatch.context() as m:
565+
m.setattr(builtins, "__import__", _bad_import)
566+
with pytest.raises(ValueError, match=f"{clist[0]} is not"):
567+
568+
_ = xr2yt._convert_to_yt_internal_coords(clist, xr_da)
569+
570+
571+
def test_coord_alias_reset():
572+
xr2yt.known_coord_aliases["blah"] = "lwkerj"
573+
xr2yt.reset_coordinate_aliases()
574+
assert "blah" not in xr2yt.known_coord_aliases
575+
576+
538577
def test_reader_with_2d_space_time_and_reverse_axis():
539578

540579
# test for https://github.com/data-exp-lab/yt_xarray/issues/86

0 commit comments

Comments
 (0)