Skip to content

Commit 1420045

Browse files
committed
improve coverage
1 parent dfe601f commit 1420045

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

dpnp/dpnp_iface_functional.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,7 @@ def piecewise(x, condlist, funclist):
338338
339339
"""
340340
dpnp.check_supported_arrays_type(x)
341-
if isinstance(condlist, tuple):
342-
condlist = list(condlist)
343-
elif isinstance(condlist, dpnp.ndarray) and condlist.ndim in [0, 1]:
341+
if isinstance(condlist, dpnp.ndarray) and condlist.ndim in [0, 1]:
344342
condlist = [condlist]
345343
elif dpnp.isscalar(condlist) or (
346344
dpnp.isscalar(condlist[0]) and x.ndim != 0
@@ -352,7 +350,7 @@ def piecewise(x, condlist, funclist):
352350
x.shape, condlist, usm_type=x.usm_type, sycl_queue=x.sycl_queue
353351
)
354352
]
355-
if not isinstance(condlist[0], (dpnp.ndarray)):
353+
elif not isinstance(condlist[0], (dpnp.ndarray)):
356354
# convert list of lists to list of arrays
357355
# convert list of scalars to a list of 0d arrays (for 0d input)
358356
tmp = []

dpnp/tests/test_functional.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,17 @@ def test_simple(self):
136136
result = dpnp.piecewise(ia, [dpnp.array([1, 0])], [1])
137137
assert_array_equal(result, expected)
138138

139-
def test_error(self):
139+
# List of conditions: single bool tuple
140+
expected = numpy.piecewise(a, ([True, False], [False, True]), [1, -4])
141+
result = dpnp.piecewise(ia, ([True, False], [False, True]), [1, -4])
142+
assert_array_equal(result, expected)
143+
144+
# Condition is single bool tuple
145+
expected = numpy.piecewise(a, (True, False), [1])
146+
result = dpnp.piecewise(ia, (True, False), [1])
147+
assert_array_equal(result, expected)
148+
149+
def test_error_dpnp(self):
140150
ia = dpnp.array([0, 0])
141151
# values cannot be a callable function
142152
assert_raises_regex(
@@ -158,23 +168,45 @@ def test_error(self):
158168
[-1, lambda x: 1],
159169
)
160170

171+
# funclist is not array-like
172+
assert_raises_regex(
173+
TypeError,
174+
"funclist must be a sequence of scalars",
175+
dpnp.piecewise,
176+
ia,
177+
[dpnp.array([True, False])],
178+
1,
179+
)
180+
181+
assert_raises_regex(
182+
TypeError,
183+
"object of type",
184+
numpy.piecewise,
185+
ia.asnumpy(),
186+
[numpy.array([True, False])],
187+
1,
188+
)
189+
190+
@pytest.mark.parametrize("xp", [dpnp, numpy])
191+
def test_error(self, xp):
192+
ia = xp.array([0, 0])
161193
# not enough functions
162194
assert_raises_regex(
163195
ValueError,
164196
"1 or 2 functions are expected",
165-
dpnp.piecewise,
197+
xp.piecewise,
166198
ia,
167-
[dpnp.array([True, False])],
199+
[xp.array([True, False])],
168200
[],
169201
)
170202

171203
# extra function
172204
assert_raises_regex(
173205
ValueError,
174206
"1 or 2 functions are expected",
175-
dpnp.piecewise,
207+
xp.piecewise,
176208
ia,
177-
[dpnp.array([True, False])],
209+
[xp.array([True, False])],
178210
[1, 2, 3],
179211
)
180212

0 commit comments

Comments
 (0)