Skip to content

Commit 4cf63c9

Browse files
author
Andrew Barna
committed
Get pint unit registries from the input Quantities
1 parent 45b5a97 commit 4cf63c9

File tree

2 files changed

+41
-22
lines changed

2 files changed

+41
-22
lines changed

gsw_xarray/_core.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
import pint_xarray
1313
import pint
1414

15-
ureg = pint.UnitRegistry()
16-
Q_ = ureg.Quantity
1715
except ImportError:
1816
pint_xarray = None
1917

@@ -24,75 +22,84 @@ def add_attrs(rv, attrs, name):
2422
rv.attrs = attrs
2523

2624

27-
def quantify(rv, attrs):
25+
def quantify(rv, attrs, unit_registry=None):
26+
if unit_registry is None:
27+
return rv
28+
2829
if isinstance(rv, xr.DataArray):
29-
rv = rv.pint.quantify()
30+
rv = rv.pint.quantify(unit_registry=unit_registry)
3031
else:
3132
if attrs is not None:
3233
# Necessary to use the Q_ and not simply multiplication with ureg unit because of temperature
3334
# see https://pint.readthedocs.io/en/latest/nonmult.html
34-
rv = Q_(rv, attrs["units"])
35+
rv = unit_registry.Quantity(rv, attrs["units"])
3536
return rv
3637

3738

3839
def pint_compat(args, kwargs):
3940
if pint_xarray is None:
40-
return args, kwargs, False
41+
return args, kwargs, None
4142

4243
using_pint = False
4344
new_args = []
4445
new_kwargs = {}
46+
registries = []
4547
for arg in args:
4648
if isinstance(arg, xr.DataArray):
4749
if arg.pint.units is not None:
4850
new_args.append(arg.pint.dequantify())
49-
using_pint = True
51+
registries.append(arg.pint.registry)
5052
else:
5153
new_args.append(arg)
5254
elif isinstance(arg, pint.Quantity):
5355
new_args.append(arg.magnitude)
54-
using_pint = True
56+
registries.append(arg._REGISTRY)
5557
else:
5658
new_args.append(arg)
5759

5860
for kw, arg in kwargs.items():
5961
if isinstance(arg, xr.DataArray):
6062
if arg.pint.units is not None:
6163
new_kwargs[kw] = arg.pint.dequantify()
62-
using_pint = True
64+
registries.append(arg.pint.registry)
6365
else:
6466
new_kwargs[kw] = arg
6567
elif isinstance(arg, pint.Quantity):
6668
new_kwargs[kw] = arg.magnitude
67-
using_pint = True
69+
registries.append(arg._REGISTRY)
6870
else:
6971
new_kwargs[kw] = arg
7072

71-
return new_args, new_kwargs, using_pint
73+
registries = set(registries)
74+
if len(registries) > 1:
75+
raise ValueError("Quantity arguments must all belong to the same unit registry")
76+
elif len(registries) == 0:
77+
registries = None
78+
else:
79+
(registries,) = registries
80+
return new_args, new_kwargs, registries
7281

7382

7483
def cf_attrs(attrs, name, check_func):
7584
def cf_attrs_decorator(func):
7685
@wraps(func)
7786
def cf_attrs_wrapper(*args, **kwargs):
78-
args, kwargs, is_quantity = pint_compat(args, kwargs)
87+
args, kwargs, unit_registry = pint_compat(args, kwargs)
7988
rv = func(*args, **kwargs)
8089
attrs_checked = check_func(attrs, args, kwargs)
8190
if isinstance(rv, tuple):
8291
rv_updated = []
8392
for (i, da) in enumerate(rv):
8493
add_attrs(da, attrs_checked[i], name[i])
85-
if is_quantity:
86-
rv_updated.append(quantify(da, attrs_checked[i]))
87-
else:
88-
rv_updated.append(da)
94+
rv_updated.append(
95+
quantify(da, attrs_checked[i], unit_registry=unit_registry)
96+
)
8997

9098
rv = tuple(rv_updated)
9199

92100
else:
93101
add_attrs(rv, attrs_checked, name)
94-
if is_quantity:
95-
rv = quantify(rv, attrs_checked)
102+
rv = quantify(rv, attrs_checked, unit_registry=unit_registry)
96103
return rv
97104

98105
return cf_attrs_wrapper

gsw_xarray/tests/test_units.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,8 @@ def test_func_return_tuple_quantity(ds_pint):
8585
def test_pint_quantity_xarray(ds):
8686
"""If input is mixed between xr.DataArray and pint quantity it should return pint-xarray wrapped quantity"""
8787
pint_xarray = pytest.importorskip("pint_xarray")
88-
import pint
8988

90-
ureg = pint.UnitRegistry()
89+
ureg = pint_xarray.unit_registry
9190
Q_ = ureg.Quantity
9291
sigma0 = gsw_xarray.sigma0(SA=ds.SA, CT=Q_(25.4, ureg.degC))
9392
assert sigma0.pint.units == pint_xarray.unit_registry("kg / m^3")
@@ -98,7 +97,7 @@ def test_pint_quantity():
9897
pint_xarray = pytest.importorskip("pint_xarray")
9998
import pint
10099

101-
ureg = pint.UnitRegistry()
100+
ureg = pint_xarray.unit_registry
102101
CT = gsw_xarray.CT_from_pt(SA=35 * ureg("g / kg"), pt=10)
103102
assert isinstance(CT, pint.Quantity)
104103

@@ -108,7 +107,20 @@ def test_pint_quantity_tuple():
108107
pint_xarray = pytest.importorskip("pint_xarray")
109108
import pint
110109

111-
ureg = pint.UnitRegistry()
110+
ureg = pint_xarray.unit_registry
112111
(a, b) = gsw_xarray.CT_first_derivatives(35 * ureg("g / kg"), pt=1)
113112
assert isinstance(a, pint.Quantity)
114113
assert isinstance(b, pint.Quantity)
114+
115+
116+
def test_mixed_unit_regestiries():
117+
"""If input quantities are from different registries, it should fail"""
118+
pint_xarray = pytest.importorskip("pint_xarray")
119+
import pint
120+
121+
ureg_a = pint.UnitRegistry()
122+
ureg_b = pint.UnitRegistry()
123+
with pytest.raises(ValueError):
124+
gsw_xarray.CT_first_derivatives(
125+
35 * ureg_a("g / kg"), pt=ureg_b.Quantity(1, ureg_b.degC)
126+
)

0 commit comments

Comments
 (0)