Skip to content

Commit 7808fda

Browse files
committed
add correction for tuples + correct error in quantity if tuples outputed
1 parent 54944d6 commit 7808fda

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

gsw_xarray/_core.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616

1717

1818
def add_attrs(rv, attrs, name):
19-
rv.name = name
20-
rv.attrs = attrs
19+
if isinstance(rv, xr.DataArray):
20+
rv.name = name
21+
rv.attrs = attrs
2122

2223

2324
def pint_compat(args, kwargs):
@@ -68,14 +69,13 @@ def cf_attrs_wrapper(*args, **kwargs):
6869
for (i, da) in enumerate(rv):
6970
add_attrs(da, attrs_checked[i], name[i])
7071
if is_quantity:
71-
da = rv.pint.quantify()
72-
rv_updated.append(da)
72+
rv_updated.append(da.pint.quantify())
7373
else:
7474
rv_updated.append(da)
7575

7676
rv = tuple(rv_updated)
7777

78-
elif isinstance(rv, xr.DataArray):
78+
else:
7979
add_attrs(rv, attrs_checked, name)
8080
if is_quantity:
8181
rv = rv.pint.quantify()

gsw_xarray/tests/test_gsw_xarray.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from gsw_xarray import __version__
22
import gsw_xarray as gsw
33
import xarray as xr
4+
import numpy as np
45
import pytest
56

67

@@ -22,6 +23,12 @@ def test_func_return_tuple(ds):
2223
assert CT_SA.attrs["units"] == "K/(g/kg)"
2324

2425

26+
def test_func_return_tuple_ndarray(ds):
27+
(CT_SA, CT_pt) = gsw.CT_first_derivatives(ds.SA.data, 1)
28+
assert isinstance(CT_SA, np.ndarray)
29+
assert isinstance(CT_pt, np.ndarray)
30+
31+
2532
@pytest.mark.parametrize("gsdh", [0, 1, None])
2633
@pytest.mark.parametrize("ssg", [0, 1, None])
2734
@pytest.mark.parametrize(

gsw_xarray/tests/test_units.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,9 @@ def test_xarray_quantity_or_ds(ds, ds_pint, SA_type, CT_type):
7171
else:
7272
assert sigma0.pint.units is None
7373
assert sigma0.pint.quantify().pint.units == pint_xarray.unit_registry("kg / m^3")
74+
75+
76+
def test_func_return_tuple_quantity(ds_pint):
77+
pint_xarray = pytest.importorskip("pint_xarray")
78+
(CT_SA, CT_pt) = gsw_xarray.CT_first_derivatives(ds_pint.SA, 1)
79+
assert CT_SA.pint.units == pint_xarray.unit_registry("K/(g/kg)")

0 commit comments

Comments
 (0)