Skip to content

Commit 92bf1aa

Browse files
committed
add support for pint only quantities
1 parent 1557ec1 commit 92bf1aa

File tree

2 files changed

+61
-13
lines changed

2 files changed

+61
-13
lines changed

gsw_xarray/_core.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
try:
1212
import pint_xarray
1313
import pint
14+
15+
ureg = pint.UnitRegistry()
16+
Q_ = ureg.Quantity
1417
except ImportError:
1518
pint_xarray = None
1619

@@ -20,6 +23,16 @@ def add_attrs(rv, attrs, name):
2023
rv.name = name
2124
rv.attrs = attrs
2225

26+
27+
def quantify(rv, attrs):
28+
if isinstance(rv, xr.DataArray):
29+
rv = rv.pint.quantify()
30+
else:
31+
if attrs is not None:
32+
rv = Q_(rv, attrs["units"])
33+
return rv
34+
35+
2336
def pint_compat(args, kwargs):
2437
if pint_xarray is None:
2538
return args, kwargs, False
@@ -68,16 +81,16 @@ def cf_attrs_wrapper(*args, **kwargs):
6881
for (i, da) in enumerate(rv):
6982
add_attrs(da, attrs_checked[i], name[i])
7083
if is_quantity:
71-
rv_updated.append(da.pint.quantify())
84+
rv_updated.append(quantify(da, attrs_checked[i]))
7285
else:
7386
rv_updated.append(da)
7487

7588
rv = tuple(rv_updated)
76-
89+
7790
else:
7891
add_attrs(rv, attrs_checked, name)
7992
if is_quantity:
80-
rv = rv.pint.quantify()
93+
rv = quantify(rv, attrs_checked)
8194
return rv
8295

8396
return cf_attrs_wrapper

gsw_xarray/tests/test_units.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,30 +50,65 @@ def test_xarray_quantity(ds_pint):
5050
sigma0 = gsw_xarray.sigma0(SA=ds_pint.SA, CT=ds_pint.CT)
5151
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")
5252

53-
@pytest.mark.parametrize("SA_type", ['unit', 'ds'])
54-
@pytest.mark.parametrize("CT_type", ['unit', 'ds'])
53+
54+
@pytest.mark.parametrize("SA_type", ["unit", "ds"])
55+
@pytest.mark.parametrize("CT_type", ["unit", "ds"])
5556
def test_xarray_quantity_or_ds(ds, ds_pint, SA_type, CT_type):
5657
"""If at least 1 of the inputs is quantity, the result should be quantity"""
5758
pint_xarray = pytest.importorskip("pint_xarray")
58-
if SA_type == 'unit':
59+
if SA_type == "unit":
5960
SA = ds_pint.SA
60-
elif SA_type == 'ds':
61+
elif SA_type == "ds":
6162
SA = ds.SA
62-
63-
if CT_type == 'unit':
63+
64+
if CT_type == "unit":
6465
CT = ds_pint.CT
65-
elif CT_type == 'ds':
66+
elif CT_type == "ds":
6667
CT = ds.CT
67-
68+
6869
sigma0 = gsw_xarray.sigma0(SA=SA, CT=CT)
69-
if SA_type == 'unit' or CT_type == 'unit':
70+
if SA_type == "unit" or CT_type == "unit":
7071
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")
7172
else:
7273
assert sigma0.pint.units is None
73-
assert sigma0.pint.quantify().pint.units == pint_xarray.unit_registry("kg / m^3")
74+
assert sigma0.pint.quantify().pint.units == pint_xarray.unit_registry(
75+
"kg / m^3"
76+
)
7477

7578

7679
def test_func_return_tuple_quantity(ds_pint):
7780
pint_xarray = pytest.importorskip("pint_xarray")
7881
(CT_SA, CT_pt) = gsw_xarray.CT_first_derivatives(ds_pint.SA, 1)
7982
assert CT_SA.pint.units == pint_xarray.unit_registry("K/(g/kg)")
83+
84+
85+
def test_pint_quantity_xarray(ds):
86+
"""If input is mixed between xr.DataArray and pint quantity it should return pint-xarray wrapped quantity"""
87+
pint_xarray = pytest.importorskip("pint_xarray")
88+
import pint
89+
90+
ureg = pint.UnitRegistry()
91+
Q_ = ureg.Quantity
92+
sigma0 = gsw_xarray.sigma0(SA=ds.SA, CT=Q_(25.4, ureg.degC))
93+
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")
94+
95+
96+
def test_pint_quantity():
97+
"""If input is pint quantity should return a quantity"""
98+
pint_xarray = pytest.importorskip("pint_xarray")
99+
import pint
100+
101+
ureg = pint.UnitRegistry()
102+
CT = gsw_xarray.CT_from_pt(SA=35 * ureg("g / kg"), pt=10)
103+
assert isinstance(CT, pint.Quantity)
104+
105+
106+
def test_pint_quantity_tuple():
107+
"""If input is pint quantity should return a quantity"""
108+
pint_xarray = pytest.importorskip("pint_xarray")
109+
import pint
110+
111+
ureg = pint.UnitRegistry()
112+
(a, b) = gsw_xarray.CT_first_derivatives(35 * ureg("g / kg"), pt=1)
113+
assert isinstance(a, pint.Quantity)
114+
assert isinstance(b, pint.Quantity)

0 commit comments

Comments
 (0)