Skip to content

Commit 3b37096

Browse files
authored
Merge pull request #41 from DocOtak/pint
Add direct pint/pint-xarray support
2 parents 99aff72 + 05cb8b8 commit 3b37096

File tree

8 files changed

+521
-198
lines changed

8 files changed

+521
-198
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ jobs:
6363
poetry-version: ${{ matrix.poetry-version }}
6464
- name: Install dependencies
6565
run: |
66-
poetry install
66+
poetry install -E pint
6767
poetry run pip install cf_units==3.0.1
6868
- name: Test with pytest
6969
run: |

CHANGELOG.rst

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
Changelog
22
=========
3+
4+
v0.3.0 - Unreleased
5+
-------------------
6+
This release will focus on supporting Pint quantities.
7+
8+
Highlights
9+
``````````
10+
* If (at least 1) arguments are xarray.DataArray wrapped by pint-xarray, the result is wrapped into a pint-xarray quantity.
11+
312
v0.2.1 - 2022-03-22
413
-------------------
514
Despite all the checking, we missed a bad bug.
@@ -37,4 +46,4 @@ Breaking Changes
3746
v0.1.0 - 2021-12-15
3847
-------------------
3948
* Original release, was basically a proof of concept.
40-
Only a few functions were wrapped.
49+
Only a few functions were wrapped.

README.rst

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,69 @@ Outputs
6161

6262
<class 'numpy.ndarray'> [-5.08964499 2.1101098 9.28348219]
6363

64+
65+
We support (but don't yet validate) the usage of pint.Quantities and the usage of xarray wrapped Quantities.
66+
Support for pint requires the installation of two optional dependencies: ``pint`` and ``pint-xarray``.
67+
If any of the inputs to a gsw function are Quantities, the returned object will also be a Quantity belonging to the same UnitRegistry.
68+
69+
.. warning::
70+
71+
Quantities must all belong to the same pint.UnitRegistry, a ValueError will be thrown if there are mixed registries.
72+
73+
.. code:: python
74+
75+
import pint_xarray
76+
import gsw_xarray as gsw
77+
78+
# Create a xarray.Dataset
79+
import numpy as np
80+
import xarray as xr
81+
ds = xr.Dataset()
82+
id = np.arange(3)
83+
ds['id'] = xr.DataArray(id, coords={'id':id})
84+
ds['CT'] = ds['id'] * 10
85+
# make sure there are unit attrs this time
86+
ds['CT'].attrs = {'standard_name':'sea_water_conservative_temperature', 'units': 'degC'}
87+
ds['SA'] = ds['id'] * 0.1 + 34
88+
ds['SA'].attrs = {'standard_name':'sea_water_absolute_salinity', 'units': 'g/kg'}
89+
90+
# use the pint accessor to quantify things
91+
ds = ds.pint.quantify()
92+
93+
# Apply gsw functions
94+
sigma0 = gsw.sigma0(SA=ds['SA'], CT=ds['CT'])
95+
# outputs are now quantities!
96+
print(sigma0)
97+
98+
Outputs
99+
100+
::
101+
102+
<xarray.DataArray 'sigma0' (id: 3)>
103+
<Quantity([27.17191038 26.12820162 24.03930887], 'kilogram / meter ** 3')>
104+
Coordinates:
105+
* id (id) int64 0 1 2
106+
Attributes:
107+
standard_name: sea_water_sigma_t
108+
109+
The usage of xarray wrapped Quantities is not required, you can use pint directly (though the ``pint-xarray`` dep still needs to be installed).
110+
111+
.. code:: python
112+
113+
import gsw_xarray as gsw
114+
import pint
115+
ureg = pint.UnitRegistry()
116+
SA = ureg.Quantity(35, ureg("g/kg"))
117+
CT = ureg.Quantity(10, ureg.degC)
118+
sigma0 = gsw.sigma0(SA=SA, CT=CT)
119+
print(sigma0)
120+
121+
Outputs
122+
123+
::
124+
125+
26.824644457868317 kilogram / meter ** 3
126+
64127
Installation
65128
------------
66129
Pip

gsw_xarray/_core.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from functools import wraps
2+
from itertools import chain
23

34
import gsw
45
import xarray as xr
@@ -7,24 +8,98 @@
78
from ._names import _names
89
from ._check_funcs import _check_funcs
910

11+
try:
12+
import pint_xarray
13+
import pint
14+
15+
except ImportError:
16+
pint_xarray = None
17+
1018

1119
def add_attrs(rv, attrs, name):
1220
if isinstance(rv, xr.DataArray):
1321
rv.name = name
1422
rv.attrs = attrs
1523

1624

25+
def quantify(rv, attrs, unit_registry=None):
26+
if unit_registry is None:
27+
return rv
28+
29+
if isinstance(rv, xr.DataArray):
30+
rv = rv.pint.quantify(unit_registry=unit_registry)
31+
else:
32+
if attrs is not None:
33+
# Necessary to use the Q_ and not simply multiplication with ureg unit because of temperature
34+
# see https://pint.readthedocs.io/en/latest/nonmult.html
35+
rv = unit_registry.Quantity(rv, attrs["units"])
36+
return rv
37+
38+
39+
def pint_compat(args, kwargs):
40+
if pint_xarray is None:
41+
return args, kwargs, None
42+
43+
using_pint = False
44+
new_args = []
45+
new_kwargs = {}
46+
registries = []
47+
for arg in args:
48+
if isinstance(arg, xr.DataArray):
49+
if arg.pint.units is not None:
50+
new_args.append(arg.pint.dequantify())
51+
registries.append(arg.pint.registry)
52+
else:
53+
new_args.append(arg)
54+
elif isinstance(arg, pint.Quantity):
55+
new_args.append(arg.magnitude)
56+
registries.append(arg._REGISTRY)
57+
else:
58+
new_args.append(arg)
59+
60+
for kw, arg in kwargs.items():
61+
if isinstance(arg, xr.DataArray):
62+
if arg.pint.units is not None:
63+
new_kwargs[kw] = arg.pint.dequantify()
64+
registries.append(arg.pint.registry)
65+
else:
66+
new_kwargs[kw] = arg
67+
elif isinstance(arg, pint.Quantity):
68+
new_kwargs[kw] = arg.magnitude
69+
registries.append(arg._REGISTRY)
70+
else:
71+
new_kwargs[kw] = arg
72+
73+
registries = set(registries)
74+
if len(registries) > 1:
75+
raise ValueError("Quantity arguments must all belong to the same unit registry")
76+
elif len(registries) == 0:
77+
registries = None
78+
else:
79+
(registries,) = registries
80+
return new_args, new_kwargs, registries
81+
82+
1783
def cf_attrs(attrs, name, check_func):
1884
def cf_attrs_decorator(func):
1985
@wraps(func)
2086
def cf_attrs_wrapper(*args, **kwargs):
87+
args, kwargs, unit_registry = pint_compat(args, kwargs)
2188
rv = func(*args, **kwargs)
2289
attrs_checked = check_func(attrs, args, kwargs)
2390
if isinstance(rv, tuple):
91+
rv_updated = []
2492
for (i, da) in enumerate(rv):
2593
add_attrs(da, attrs_checked[i], name[i])
94+
rv_updated.append(
95+
quantify(da, attrs_checked[i], unit_registry=unit_registry)
96+
)
97+
98+
rv = tuple(rv_updated)
99+
26100
else:
27101
add_attrs(rv, attrs_checked, name)
102+
rv = quantify(rv, attrs_checked, unit_registry=unit_registry)
28103
return rv
29104

30105
return cf_attrs_wrapper

gsw_xarray/tests/conftest.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,23 @@ def ds():
1010
id = np.arange(3)
1111
ds["id"] = xr.DataArray(id, coords={"id": id})
1212
ds["CT"] = ds["id"] * 10
13-
ds["CT"].attrs = {"standard_name": "sea_water_conservative_temperature"}
13+
ds["CT"].attrs = {
14+
"standard_name": "sea_water_conservative_temperature",
15+
"units": "degC",
16+
}
1417
ds["SA"] = ds["id"] * 0.1 + 34
15-
ds["SA"].attrs = {"standard_name": "sea_water_absolute_salinity"}
18+
ds["SA"].attrs = {"standard_name": "sea_water_absolute_salinity", "units": "g/kg"}
1619
return ds
1720

1821

1922
@pytest.fixture(scope="session")
2023
def ureg():
2124
pint = pytest.importorskip("pint")
2225
return pint.UnitRegistry()
26+
27+
28+
@pytest.fixture
29+
def ds_pint(ds, ureg):
30+
pytest.importorskip("pint_xarray")
31+
32+
return ds.pint.quantify()

gsw_xarray/tests/test_units.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
"""
44
import pytest
55

6+
import gsw_xarray
7+
68
from .test_imports import gsw_base
79
from gsw_xarray._attributes import _func_attrs
810

@@ -41,3 +43,84 @@ def test_unit_cf_units(func_name):
4143
for a in attrs:
4244
print(a["units"])
4345
cf_units.Unit(a["units"])
46+
47+
48+
def test_xarray_quantity(ds_pint):
49+
pint_xarray = pytest.importorskip("pint_xarray")
50+
sigma0 = gsw_xarray.sigma0(SA=ds_pint.SA, CT=ds_pint.CT)
51+
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")
52+
53+
54+
@pytest.mark.parametrize("SA_type", ["unit", "ds"])
55+
@pytest.mark.parametrize("CT_type", ["unit", "ds"])
56+
def test_xarray_quantity_or_ds(ds, ds_pint, SA_type, CT_type):
57+
"""If at least 1 of the inputs is quantity, the result should be quantity"""
58+
pint_xarray = pytest.importorskip("pint_xarray")
59+
if SA_type == "unit":
60+
SA = ds_pint.SA
61+
elif SA_type == "ds":
62+
SA = ds.SA
63+
64+
if CT_type == "unit":
65+
CT = ds_pint.CT
66+
elif CT_type == "ds":
67+
CT = ds.CT
68+
69+
sigma0 = gsw_xarray.sigma0(SA=SA, CT=CT)
70+
if SA_type == "unit" or CT_type == "unit":
71+
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")
72+
else:
73+
assert sigma0.pint.units is None
74+
assert sigma0.pint.quantify().pint.units == pint_xarray.unit_registry(
75+
"kg / m^3"
76+
)
77+
78+
79+
def test_func_return_tuple_quantity(ds_pint):
80+
pint_xarray = pytest.importorskip("pint_xarray")
81+
(CT_SA, CT_pt) = gsw_xarray.CT_first_derivatives(ds_pint.SA, 1)
82+
assert CT_SA.pint.units == pint_xarray.unit_registry("K/(g/kg)")
83+
84+
85+
def test_pint_quantity_xarray(ds):
86+
"""If input is mixed between xr.DataArray and pint quantity it should return pint-xarray wrapped quantity"""
87+
pint_xarray = pytest.importorskip("pint_xarray")
88+
89+
ureg = pint_xarray.unit_registry
90+
Q_ = ureg.Quantity
91+
sigma0 = gsw_xarray.sigma0(SA=ds.SA, CT=Q_(25.4, ureg.degC))
92+
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")
93+
94+
95+
def test_pint_quantity():
96+
"""If input is pint quantity should return a quantity"""
97+
pint_xarray = pytest.importorskip("pint_xarray")
98+
import pint
99+
100+
ureg = pint_xarray.unit_registry
101+
CT = gsw_xarray.CT_from_pt(SA=35 * ureg("g / kg"), pt=10)
102+
assert isinstance(CT, pint.Quantity)
103+
104+
105+
def test_pint_quantity_tuple():
106+
"""If input is pint quantity should return a quantity"""
107+
pint_xarray = pytest.importorskip("pint_xarray")
108+
import pint
109+
110+
ureg = pint_xarray.unit_registry
111+
(a, b) = gsw_xarray.CT_first_derivatives(35 * ureg("g / kg"), pt=1)
112+
assert isinstance(a, pint.Quantity)
113+
assert isinstance(b, pint.Quantity)
114+
115+
116+
def test_mixed_unit_regestiries():
117+
"""If input quantities are from different registries, it should fail"""
118+
pint_xarray = pytest.importorskip("pint_xarray")
119+
import pint
120+
121+
ureg_a = pint.UnitRegistry()
122+
ureg_b = pint.UnitRegistry()
123+
with pytest.raises(ValueError):
124+
gsw_xarray.CT_first_derivatives(
125+
35 * ureg_a("g / kg"), pt=ureg_b.Quantity(1, ureg_b.degC)
126+
)

0 commit comments

Comments
 (0)