Skip to content

Commit 67a903b

Browse files
authored
Merge pull request #67 from efiring/array_ufunc
Handle xarray DataArray in wrapped ufuncs
2 parents a53c2c2 + 0a82070 commit 67a903b

File tree

5 files changed

+174
-13
lines changed

5 files changed

+174
-13
lines changed

.travis.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ before_install:
3636
conda create --name TEST python=$PY --file requirements-dev.txt --quiet
3737
source activate TEST
3838
# Install after to ensure it will be downgraded when testing an older version.
39-
conda install numpy=$NUMPY
39+
conda install numpy=$NUMPY xarray dask
4040
conda info --all
4141
4242
# Test source distribution.
@@ -64,7 +64,7 @@ script:
6464
pushd docs
6565
make clean html linkcheck
6666
popd
67-
if [[ -z "$TRAVIS_TAG" ]]; then
67+
if [[ -z "$TRAVIS_TAG" ]]; then
6868
python -m doctr deploy --build-tags --key-path github_deploy_key.enc --built-docs docs/_build/html dev
6969
else
7070
python -m doctr deploy --build-tags --key-path github_deploy_key.enc --built-docs docs/_build/html "version-$TRAVIS_TAG"

gsw/_utilities.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,35 @@ def wrapper(*args, **kw):
2626
args = list(args)
2727
args.append(p)
2828

29-
isarray = np.any([hasattr(a, '__iter__') for a in args])
30-
ismasked = np.any([np.ma.isMaskedArray(a) for a in args])
29+
isarray = [hasattr(a, '__iter__') for a in args]
30+
ismasked = [np.ma.isMaskedArray(a) for a in args]
31+
isduck = [hasattr(a, '__array_ufunc__')
32+
and not isinstance(a, np.ndarray) for a in args]
33+
34+
hasarray = np.any(isarray)
35+
hasmasked = np.any(ismasked)
36+
hasduck = np.any(isduck)
3137

3238
def fixup(ret):
33-
if ismasked:
39+
if hasduck:
40+
return ret
41+
if hasmasked:
3442
ret = np.ma.masked_invalid(ret)
35-
if not isarray and isinstance(ret, np.ndarray):
36-
ret = ret[0]
43+
if not hasarray and isinstance(ret, np.ndarray) and ret.size == 1:
44+
try:
45+
ret = ret[0]
46+
except IndexError:
47+
pass
3748
return ret
3849

39-
if ismasked:
40-
newargs = [masked_to_nan(a) for a in args]
41-
else:
42-
newargs = [np.asarray(a, dtype=float) for a in args]
50+
newargs = []
51+
for i, arg in enumerate(args):
52+
if ismasked[i]:
53+
newargs.append(masked_to_nan(arg))
54+
elif isduck[i]:
55+
newargs.append(arg)
56+
else:
57+
newargs.append(np.asarray(arg, dtype=float))
4358

4459
if p is not None:
4560
kw['p'] = newargs.pop()

gsw/geostrophy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def geo_strf_dyn_height(SA, CT, p, p_ref=0, axis=0, max_dp=1.0,
7373
dh = np.empty(SA.shape, dtype=float)
7474
dh.fill(np.nan)
7575

76-
order = 'F' if SA.flags.fortran else 'C'
76+
try:
77+
order = 'F' if SA.flags.fortran else 'C'
78+
except AttributeError:
79+
order = 'C' # e.g., xarray DataArray doesn't have flags
7780
for ind in indexer(SA.shape, axis, order=order):
7881
igood = goodmask[ind]
7982
# If p_ref is below the deepest value, skip the profile.

gsw/tests/check_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def find(x):
5555
"""
5656
Numpy equivalent to Matlab find.
5757
"""
58-
return np.nonzero(x.flatten())[0]
58+
return np.nonzero(np.asarray(x).flatten())[0]
5959

6060

6161
def group_or(line):

gsw/tests/test_xarray.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""
2+
Tests functions with xarray inputs.
3+
4+
This version is a copy of the original test_check_functions but with
5+
an import of xarray, and conversion of the 3 main check cast arrays
6+
into DataArray objects.
7+
8+
An additional xarray-dask test is added.
9+
"""
10+
11+
import os
12+
import pytest
13+
14+
import numpy as np
15+
from numpy.testing import assert_allclose
16+
17+
import gsw
18+
from gsw._utilities import Bunch
19+
from check_functions import parse_check_functions
20+
21+
xr = pytest.importorskip('xarray')
22+
23+
# Most of the tests have some nan values, so we need to suppress the warning.
24+
# Any more careful fix would likely require considerable effort.
25+
np.seterr(invalid='ignore')
26+
27+
root_path = os.path.abspath(os.path.dirname(__file__))
28+
29+
# Function checks that we can't handle automatically yet.
30+
blacklist = ['deltaSA_atlas', # the test is complicated; doesn't fit the pattern.
31+
'geostrophic_velocity', # test elsewhere; we changed the API
32+
#'CT_from_entropy', # needs prior entropy_from_CT; don't have it in C
33+
#'CT_first_derivatives', # passes, but has trouble in "details";
34+
# see check_functions.py
35+
#'entropy_second_derivatives', # OK now; handling extra parens.
36+
#'melting_ice_into_seawater', # OK now; fixed nargs mismatch.
37+
]
38+
39+
# We get an overflow from ct_from_enthalpy_exact, but the test passes.
40+
cv = Bunch(np.load(os.path.join(root_path, 'gsw_cv_v3_0.npz')))
41+
42+
# Substitute new check values for the pchip interpolation version.
43+
cv.geo_strf_dyn_height = np.load(os.path.join(root_path,'geo_strf_dyn_height.npy'))
44+
cv.geo_strf_velocity = np.load(os.path.join(root_path,'geo_strf_velocity.npy'))
45+
46+
for name in ['SA_chck_cast', 't_chck_cast', 'p_chck_cast']:
47+
cv[name] = xr.DataArray(cv[name])
48+
49+
cf = Bunch()
50+
51+
d = dir(gsw)
52+
funcnames = [name for name in d if '__' not in name]
53+
54+
mfuncs = parse_check_functions(os.path.join(root_path, 'gsw_check_functions_save.m'))
55+
mfuncs = [mf for mf in mfuncs if mf.name in d and mf.name not in blacklist]
56+
mfuncnames = [mf.name for mf in mfuncs]
57+
58+
59+
@pytest.fixture(scope='session', params=mfuncs)
60+
def cfcf(request):
61+
return cv, cf, request.param
62+
63+
64+
def test_check_function(cfcf):
65+
cv, cf, mfunc = cfcf
66+
mfunc.run(locals())
67+
if mfunc.exception is not None or not mfunc.passed:
68+
print('\n', mfunc.name)
69+
print(' ', mfunc.runline)
70+
print(' ', mfunc.testline)
71+
if mfunc.exception is None:
72+
mfunc.exception = ValueError('Calculated values are different from the expected matlab results.')
73+
raise mfunc.exception
74+
else:
75+
print(mfunc.name)
76+
assert mfunc.passed
77+
78+
79+
def test_dask_chunking():
80+
dsa = pytest.importorskip('dask.array')
81+
82+
# define some input data
83+
shape = (100, 1000)
84+
chunks = (100, 200)
85+
sp = xr.DataArray(dsa.full(shape, 35., chunks=chunks), dims=['time', 'depth'])
86+
p = xr.DataArray(np.arange(shape[1]), dims=['depth'])
87+
lon = 0
88+
lat = 45
89+
90+
sa = gsw.SA_from_SP(sp, p, lon, lat)
91+
sa_dask = sa.compute()
92+
93+
sa_numpy = gsw.SA_from_SP(np.full(shape, 35.0), p.values, lon, lat)
94+
assert_allclose(sa_dask, sa_numpy)
95+
96+
97+
# Additional tests from Graeme MacGilchrist
98+
# https://nbviewer.jupyter.org/github/gmacgilchrist/wmt_bgc/blob/master/notebooks/test_gsw-xarray.ipynb
99+
100+
# Define dimensions and coordinates
101+
dims = ['y','z','t']
102+
# 2x2x2
103+
y = np.arange(0,2)
104+
z = np.arange(0,2)
105+
t = np.arange(0,2)
106+
# Define numpy arrays of salinity, temperature and pressure
107+
SA_vals = np.array([[[34.7,34.8],[34.9,35]],[[35.1,35.2],[35.3,35.4]]])
108+
CT_vals = np.array([[[7,8],[9,10]],[[11,12],[13,14]]])
109+
p_vals = np.array([10,20])
110+
lat_vals = np.array([0,10])
111+
# Plug in to xarray objects
112+
SA = xr.DataArray(SA_vals,dims=dims,coords={'y':y,'z':z,'t':t})
113+
CT = xr.DataArray(CT_vals,dims=dims,coords={'y':y,'z':z,'t':t})
114+
p = xr.DataArray(p_vals,dims=['z'],coords={'z':z})
115+
lat = xr.DataArray(lat_vals,dims=['y'],coords={'y':y})
116+
117+
118+
def test_xarray_with_coords():
119+
pytest.importorskip('dask')
120+
SA_chunk = SA.chunk(chunks={'y':1,'t':1})
121+
CT_chunk = CT.chunk(chunks={'y':1,'t':1})
122+
lat_chunk = lat.chunk(chunks={'y':1})
123+
124+
# Dimensions and cordinates match:
125+
expected = gsw.sigma0(SA_vals, CT_vals)
126+
xarray = gsw.sigma0(SA, CT)
127+
chunked = gsw.sigma0(SA_chunk, CT_chunk)
128+
assert_allclose(xarray, expected)
129+
assert_allclose(chunked, expected)
130+
131+
# Broadcasting along dimension required (dimensions known)
132+
expected = gsw.alpha(SA_vals, CT_vals, p_vals[np.newaxis, :, np.newaxis])
133+
xarray = gsw.alpha(SA, CT, p)
134+
chunked = gsw.alpha(SA_chunk, CT_chunk, p)
135+
assert_allclose(xarray, expected)
136+
assert_allclose(chunked, expected)
137+
138+
# Broadcasting along dimension required (dimensions unknown/exclusive)
139+
expected = gsw.z_from_p(p_vals[:, np.newaxis], lat_vals[np.newaxis, :])
140+
xarray = gsw.z_from_p(p, lat)
141+
chunked = gsw.z_from_p(p,lat_chunk)
142+
assert_allclose(xarray, expected)
143+
assert_allclose(chunked, expected)

0 commit comments

Comments
 (0)