Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jobs:
poetry-version: ${{ matrix.poetry-version }}
- name: Install dependencies
run: |
poetry install
poetry install -E pint
poetry run pip install cf_units==3.0.1
- name: Test with pytest
run: |
Expand Down
55 changes: 55 additions & 0 deletions gsw_xarray/_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import wraps
from itertools import chain

import gsw
import xarray as xr
Expand All @@ -7,23 +8,77 @@
from ._names import _names
from ._check_funcs import _check_funcs

try:
import pint_xarray
import pint
except ImportError:
pint_xarray = None


def add_attrs(rv, attrs, name):
rv.name = name
rv.attrs = attrs


def pint_compat(args, kwargs):
if pint_xarray is None:
return args, kwargs, False

using_pint = False
new_args = []
new_kwargs = {}
for arg in args:
if isinstance(arg, xr.DataArray):
if arg.pint.units is not None:
new_args.append(arg.pint.dequantify())
using_pint = True
else:
new_args.append(arg)
elif isinstance(arg, pint.Quantity):
new_args.append(arg.magnitude)
using_pint = True
else:
new_args.append(arg)

for kw, arg in kwargs.items():
if isinstance(arg, xr.DataArray):
if arg.pint.units is not None:
new_kwargs[kw] = arg.pint.dequantify()
using_pint = True
else:
new_kwargs[kw] = arg
elif isinstance(arg, pint.Quantity):
new_kwargs[kw] = arg.magnitude
using_pint = True
else:
new_kwargs[kw] = arg

return new_args, new_kwargs, using_pint


def cf_attrs(attrs, name, check_func):
def cf_attrs_decorator(func):
@wraps(func)
def cf_attrs_wrapper(*args, **kwargs):
args, kwargs, is_quantity = pint_compat(args, kwargs)
rv = func(*args, **kwargs)
attrs_checked = check_func(attrs, args, kwargs)
if isinstance(rv, tuple):
rv_updated = []
for (i, da) in enumerate(rv):
add_attrs(da, attrs_checked[i], name[i])
if is_quantity:
da = rv.pint.quantify()
rv_updated.append(da)
else:
rv_updated.append(da)

rv = tuple(rv_updated)

elif isinstance(rv, xr.DataArray):
add_attrs(rv, attrs_checked, name)
if is_quantity:
rv = rv.pint.quantify()
return rv

return cf_attrs_wrapper
Expand Down
14 changes: 12 additions & 2 deletions gsw_xarray/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,23 @@ def ds():
id = np.arange(3)
ds["id"] = xr.DataArray(id, coords={"id": id})
ds["CT"] = ds["id"] * 10
ds["CT"].attrs = {"standard_name": "sea_water_conservative_temperature"}
ds["CT"].attrs = {
"standard_name": "sea_water_conservative_temperature",
"units": "degC",
}
ds["SA"] = ds["id"] * 0.1 + 34
ds["SA"].attrs = {"standard_name": "sea_water_absolute_salinity"}
ds["SA"].attrs = {"standard_name": "sea_water_absolute_salinity", "units": "g/kg"}
return ds


@pytest.fixture(scope="session")
def ureg():
pint = pytest.importorskip("pint")
return pint.UnitRegistry()


@pytest.fixture
def ds_pint(ds, ureg):
pytest.importorskip("pint_xarray")

return ds.pint.quantify()
8 changes: 8 additions & 0 deletions gsw_xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
import pytest

import gsw_xarray

from .test_imports import gsw_base
from gsw_xarray._attributes import _func_attrs

Expand Down Expand Up @@ -41,3 +43,9 @@ def test_unit_cf_units(func_name):
for a in attrs:
print(a["units"])
cf_units.Unit(a["units"])


def test_xarray_quantity(ds_pint):
pint_xarray = pytest.importorskip("pint_xarray")
sigma0 = gsw_xarray.sigma0(SA=ds_pint.SA, CT=ds_pint.CT)
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")
Loading