Skip to content

Commit 7ad0a6a

Browse files
Add support usm_ndarray for condlist
1 parent 54e255e commit 7ad0a6a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

dpnp/dpnp_iface_functional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def piecewise(x, condlist, funclist):
336336
"""
337337
dpnp.check_supported_arrays_type(x)
338338
x_dtype = x.dtype
339-
if isinstance(condlist, dpnp.ndarray) and condlist.ndim in [0, 1]:
339+
if dpnp.is_supported_array_type(condlist) and condlist.ndim in [0, 1]:
340340
condlist = [condlist]
341341
elif dpnp.isscalar(condlist) or (
342342
dpnp.isscalar(condlist[0]) and x.ndim != 0
@@ -348,7 +348,7 @@ def piecewise(x, condlist, funclist):
348348
x.shape, condlist, usm_type=x.usm_type, sycl_queue=x.sycl_queue
349349
)
350350
]
351-
elif not isinstance(condlist[0], (dpnp.ndarray)):
351+
elif not dpnp.is_supported_array_type(condlist[0]):
352352
# convert list of lists to list of arrays
353353
# convert list of scalars to a list of 0d arrays (for 0d input)
354354
tmp = []
@@ -369,7 +369,7 @@ def piecewise(x, condlist, funclist):
369369
condlen = len(condlist)
370370
try:
371371
if isinstance(funclist, str):
372-
raise TypeError
372+
raise TypeError("funclist must be a non-string sequence")
373373
funclen = len(funclist)
374374
except TypeError as e:
375375
raise TypeError("funclist must be a sequence of scalars") from e

0 commit comments

Comments
 (0)