Skip to content

Commit 6de9f79

Browse files
committed
Handle xarray DataArray in wrapped ufuncs
1 parent a53c2c2 commit 6de9f79

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

gsw/_utilities.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,35 @@ def wrapper(*args, **kw):
2626
args = list(args)
2727
args.append(p)
2828

29-
isarray = np.any([hasattr(a, '__iter__') for a in args])
30-
ismasked = np.any([np.ma.isMaskedArray(a) for a in args])
29+
isarray = [hasattr(a, '__iter__') for a in args]
30+
ismasked = [np.ma.isMaskedArray(a) for a in args]
31+
isduck = [hasattr(a, '__array_ufunc__')
32+
and not isinstance(a, np.ndarray) for a in args]
33+
34+
hasarray = np.any(isarray)
35+
hasmasked = np.any(ismasked)
36+
hasduck = np.any(isduck)
3137

3238
def fixup(ret):
33-
if ismasked:
39+
if hasduck:
40+
return ret
41+
if hasmasked:
3442
ret = np.ma.masked_invalid(ret)
35-
if not isarray and isinstance(ret, np.ndarray):
36-
ret = ret[0]
43+
if not hasarray and isinstance(ret, np.ndarray) and ret.size == 1:
44+
try:
45+
ret = ret[0]
46+
except IndexError:
47+
pass
3748
return ret
3849

39-
if ismasked:
40-
newargs = [masked_to_nan(a) for a in args]
41-
else:
42-
newargs = [np.asarray(a, dtype=float) for a in args]
50+
newargs = []
51+
for i, arg in enumerate(args):
52+
if ismasked[i]:
53+
newargs.append(masked_to_nan(arg))
54+
elif isduck[i]:
55+
newargs.append(arg)
56+
else:
57+
newargs.append(np.asarray(arg, dtype=float))
4358

4459
if p is not None:
4560
kw['p'] = newargs.pop()

0 commit comments

Comments
 (0)