Skip to content

Commit 33262e3

Browse files
authored
Merge pull request scipy#22061 from ev-br/ndimage_array_scalars
BUG: ndimage: convert array scalars on return
2 parents 3ef4305 + e2ae3f2 commit 33262e3

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

scipy/ndimage/_support_alternative_backends.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,16 @@
1313
MODULE_NAME = 'ndimage'
1414

1515

16+
def _maybe_convert_arg(arg, xp):
17+
"""Convert arrays/scalars hiding in the sequence `arg`."""
18+
if isinstance(arg, (np.ndarray, np.generic)):
19+
return xp.asarray(arg)
20+
elif isinstance(arg, (list, tuple)):
21+
return type(arg)(_maybe_convert_arg(x, xp) for x in arg)
22+
else:
23+
return arg
24+
25+
1626
def delegate_xp(delegator, module_name):
1727
def inner(func):
1828
@functools.wraps(func)
@@ -52,10 +62,7 @@ def wrapper(*args, **kwds):
5262
return result
5363
else:
5464
# lists/tuples
55-
return type(result)(
56-
xp.asarray(x) if isinstance(x, np.ndarray) else x
57-
for x in result
58-
)
65+
return _maybe_convert_arg(result, xp)
5966
return wrapper
6067
return inner
6168

0 commit comments

Comments
 (0)