Skip to content

Commit cce12ad

Browse files
author
Vahid Tavanashad
committed
address comments
1 parent c9b00d9 commit cce12ad

File tree

3 files changed

+100
-76
lines changed

3 files changed

+100
-76
lines changed

dpnp/dpnp_iface_functional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def piecewise(x, condlist, funclist):
373373
funclen = len(funclist)
374374
except TypeError as e:
375375
raise TypeError("funclist must be a sequence of scalars") from e
376+
376377
if condlen == funclen:
377378
# default value is zero
378379
default_value = x_dtype.type(0)
@@ -384,11 +385,10 @@ def piecewise(x, condlist, funclist):
384385
"Callable functions are not supported currently"
385386
)
386387
if isinstance(default_value, dpnp.ndarray):
387-
default_value = default_value.astype(x_dtype)
388+
default_value = default_value.astype(x_dtype, copy=False)
388389
else:
389390
default_value = x_dtype.type(default_value)
390391
funclist = funclist[:-1]
391-
392392
else:
393393
raise ValueError(
394394
f"with {condlen} condition(s), either {condlen} or {condlen + 1} "
@@ -401,7 +401,7 @@ def piecewise(x, condlist, funclist):
401401
"Callable functions are not supported currently"
402402
)
403403
if isinstance(func, dpnp.ndarray):
404-
func = func.astype(x_dtype)
404+
func = func.astype(x_dtype, copy=False)
405405
else:
406406
func = x_dtype.type(func)
407407
dpnp.where(condition, func, default_value, out=result)

dpnp/tests/test_functional.py

Lines changed: 84 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -161,70 +161,6 @@ def test_simple(self):
161161
result = dpnp.piecewise(ia, (True, False), [1])
162162
assert_array_equal(result, expected)
163163

164-
def test_error_dpnp(self):
165-
ia = dpnp.array([0, 0])
166-
# values cannot be a callable function
167-
assert_raises_regex(
168-
NotImplementedError,
169-
"Callable functions are not supported currently",
170-
dpnp.piecewise,
171-
ia,
172-
[dpnp.array([True, False])],
173-
[lambda x: -1],
174-
)
175-
176-
# default value cannot be a callable function
177-
assert_raises_regex(
178-
NotImplementedError,
179-
"Callable functions are not supported currently",
180-
dpnp.piecewise,
181-
ia,
182-
[dpnp.array([True, False])],
183-
[-1, lambda x: 1],
184-
)
185-
186-
# funclist is not array-like
187-
assert_raises_regex(
188-
TypeError,
189-
"funclist must be a sequence of scalars",
190-
dpnp.piecewise,
191-
ia,
192-
[dpnp.array([True, False])],
193-
1,
194-
)
195-
196-
assert_raises_regex(
197-
TypeError,
198-
"object of type",
199-
numpy.piecewise,
200-
ia.asnumpy(),
201-
[numpy.array([True, False])],
202-
1,
203-
)
204-
205-
@pytest.mark.parametrize("xp", [dpnp, numpy])
206-
def test_error(self, xp):
207-
ia = xp.array([0, 0])
208-
# not enough functions
209-
assert_raises_regex(
210-
ValueError,
211-
"1 or 2 functions are expected",
212-
xp.piecewise,
213-
ia,
214-
[xp.array([True, False])],
215-
[],
216-
)
217-
218-
# extra function
219-
assert_raises_regex(
220-
ValueError,
221-
"1 or 2 functions are expected",
222-
xp.piecewise,
223-
ia,
224-
[xp.array([True, False])],
225-
[1, 2, 3],
226-
)
227-
228164
def test_two_conditions(self):
229165
a = numpy.array([1, 2])
230166
ia = dpnp.array(a)
@@ -316,3 +252,87 @@ def test_multidimensional_extrafunc(self):
316252
expected = numpy.piecewise(a, [a < 0, a >= 2], [-1, 1, 3])
317253
result = dpnp.piecewise(ia, [ia < 0, ia >= 2], [-1, 1, 3])
318254
assert_array_equal(result, expected)
255+
256+
def test_error_dpnp(self):
257+
ia = dpnp.array([0, 0])
258+
# values cannot be a callable function
259+
assert_raises_regex(
260+
NotImplementedError,
261+
"Callable functions are not supported currently",
262+
dpnp.piecewise,
263+
ia,
264+
[dpnp.array([True, False])],
265+
[lambda x: -1],
266+
)
267+
268+
# default value cannot be a callable function
269+
assert_raises_regex(
270+
NotImplementedError,
271+
"Callable functions are not supported currently",
272+
dpnp.piecewise,
273+
ia,
274+
[dpnp.array([True, False])],
275+
[-1, lambda x: 1],
276+
)
277+
278+
# funclist is not array-like
279+
assert_raises_regex(
280+
TypeError,
281+
"funclist must be a sequence of scalars",
282+
dpnp.piecewise,
283+
ia,
284+
[dpnp.array([True, False])],
285+
1,
286+
)
287+
288+
# funclist is a string
289+
assert_raises_regex(
290+
TypeError,
291+
"funclist must be a sequence of scalars",
292+
dpnp.piecewise,
293+
ia,
294+
[ia > 0],
295+
"q",
296+
)
297+
298+
assert_raises_regex(
299+
TypeError,
300+
"object of type",
301+
numpy.piecewise,
302+
ia.asnumpy(),
303+
[numpy.array([True, False])],
304+
1,
305+
)
306+
307+
@pytest.mark.parametrize("xp", [dpnp, numpy])
308+
def test_error(self, xp):
309+
ia = xp.array([0, 0])
310+
# not enough functions
311+
assert_raises_regex(
312+
ValueError,
313+
"1 or 2 functions are expected",
314+
xp.piecewise,
315+
ia,
316+
[xp.array([True, False])],
317+
[],
318+
)
319+
320+
# extra function
321+
assert_raises_regex(
322+
ValueError,
323+
"1 or 2 functions are expected",
324+
xp.piecewise,
325+
ia,
326+
[xp.array([True, False])],
327+
[1, 2, 3],
328+
)
329+
330+
# condlist is empty
331+
assert_raises_regex(
332+
IndexError,
333+
"index out of range",
334+
xp.piecewise,
335+
ia,
336+
[],
337+
[1, 2],
338+
)

dpnp/tests/test_usm_type.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -755,17 +755,21 @@ def test_apply_over_axes(usm_type):
755755
assert x.usm_type == y.usm_type
756756

757757

758-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
759-
def test_piecewise(usm_type):
760-
x = dpnp.array([0, 0], usm_type=usm_type)
761-
y = dpnp.array([True, False], usm_type=usm_type)
762-
z = dpnp.array([1, -1], usm_type=usm_type)
758+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
759+
@pytest.mark.parametrize("usm_type_y", list_of_usm_types)
760+
@pytest.mark.parametrize("usm_type_z", list_of_usm_types)
761+
def test_piecewise(usm_type_x, usm_type_y, usm_type_z):
762+
x = dpnp.array([0, 0], usm_type=usm_type_x)
763+
y = dpnp.array([True, False], usm_type=usm_type_y)
764+
z = dpnp.array([1, -1], usm_type=usm_type_z)
763765
result = dpnp.piecewise(x, y, z)
764-
res_usm_type = result.usm_type
765766

766-
assert x.usm_type == res_usm_type
767-
assert y.usm_type == res_usm_type
768-
assert z.usm_type == res_usm_type
767+
assert x.usm_type == usm_type_x
768+
assert y.usm_type == usm_type_y
769+
assert z.usm_type == usm_type_z
770+
assert result.usm_type == du.get_coerced_usm_type(
771+
[usm_type_x, usm_type_y, usm_type_z]
772+
)
769773

770774

771775
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)