Skip to content

Commit c3d57ee

Browse files
committed
Adds new choose tests
Also corrects errors for unexpected dtype to TypeError to match NumPy
1 parent deb0b47 commit c3d57ee

File tree

2 files changed

+80
-17
lines changed

2 files changed

+80
-17
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _choose_run(inds, chcs, q, usm_type, out=None, mode=0):
158158
)
159159

160160
if chcs[0].dtype != out.dtype:
161-
raise ValueError(
161+
raise TypeError(
162162
f"Output array of type {chcs[0].dtype} is needed, "
163163
f"got {out.dtype}"
164164
)
@@ -267,7 +267,12 @@ def choose(a, choices, out=None, mode="wrap"):
267267
inds = dpnp.get_usm_ndarray(a)
268268
ind_dt = inds.dtype
269269
if not dpnp.issubdtype(ind_dt, dpnp.integer):
270-
raise ValueError("input index array must be of integer data type")
270+
# NumPy will cast up to to int64 in general but
271+
# int32 is more than safe for bool
272+
if ind_dt == dpnp.bool:
273+
inds = dpt.astype(inds, dpt.int32)
274+
else:
275+
raise TypeError("input index array must be of integer data type")
271276

272277
choices = _build_choices_list(choices)
273278

dpnp/tests/test_indexing.py

Lines changed: 73 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -879,21 +879,6 @@ def test_mode_clip(self):
879879
assert (result == dpnp.array([-2, 0, -2, 2])).all()
880880

881881

882-
def test_choose():
883-
a = numpy.r_[:4]
884-
ia = dpnp.array(a)
885-
b = numpy.r_[-4:0]
886-
ib = dpnp.array(b)
887-
c = numpy.r_[100:500:100]
888-
ic = dpnp.array(c)
889-
890-
inds_np = numpy.zeros(4, dtype="i4")
891-
inds = dpnp.zeros(4, dtype="i4")
892-
expected = numpy.choose(inds_np, [a, b, c])
893-
result = dpnp.choose(inds, [ia, ib, ic])
894-
assert_array_equal(expected, result)
895-
896-
897882
@pytest.mark.parametrize("val", [-1, 0, 1], ids=["-1", "0", "1"])
898883
@pytest.mark.parametrize(
899884
"array",
@@ -1448,3 +1433,76 @@ def test_compress_strided(self):
14481433
result = dpnp.compress(cond, a)
14491434
expected = numpy.compress(cond_np, a_np)
14501435
assert_array_equal(result, expected)
1436+
1437+
1438+
class TestChoose:
1439+
def test_choose_basic(self):
1440+
indices = [0, 1, 0]
1441+
# use a single array for choices
1442+
chcs_np = numpy.arange(2 * len(indices))
1443+
chcs = dpnp.arange(2 * len(indices))
1444+
inds_np = numpy.array(indices)
1445+
inds = dpnp.array(indices)
1446+
expected = numpy.choose(inds_np, chcs_np)
1447+
result = dpnp.choose(inds, chcs)
1448+
assert_array_equal(expected, result)
1449+
1450+
def test_choose_method_basic(self):
1451+
indices = [0, 1, 2]
1452+
# use a single array for choices
1453+
chcs_np = numpy.arange(3 * len(indices))
1454+
chcs = dpnp.arange(3 * len(indices))
1455+
inds_np = numpy.array(indices)
1456+
inds = dpnp.array(indices)
1457+
expected = inds_np.choose(chcs_np)
1458+
result = inds.choose(chcs)
1459+
assert_array_equal(expected, result)
1460+
1461+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
1462+
def test_choose_inds_all_dtypes(self, dtype):
1463+
if not dpnp.issubdtype(dtype, dpnp.integer) and dtype != dpnp.bool:
1464+
inds = dpnp.zeros(1, dtype=dtype)
1465+
chcs = dpnp.ones(1, dtype=dtype)
1466+
with pytest.raises(TypeError):
1467+
dpnp.choose(inds, chcs)
1468+
else:
1469+
inds_np = numpy.array([1, 0, 1], dtype=dtype)
1470+
inds = dpnp.array(inds_np)
1471+
chcs_np = numpy.array([1, 2, 3], dtype=dtype)
1472+
chcs = dpnp.array(chcs_np)
1473+
expected = numpy.choose(inds_np, chcs_np)
1474+
result = dpnp.choose(inds, chcs)
1475+
assert_array_equal(expected, result)
1476+
1477+
def test_choose_invalid_out_errors(self):
1478+
q1 = dpctl.SyclQueue()
1479+
q2 = dpctl.SyclQueue()
1480+
chcs = dpnp.ones(10, dtype="i4", sycl_queue=q1)
1481+
inds = dpnp.zeros(10, dtype="i4", sycl_queue=q1)
1482+
out_bad_shape = dpnp.empty(11, dtype=chcs.dtype, sycl_queue=q1)
1483+
with pytest.raises(ValueError):
1484+
dpnp.choose(inds, [chcs], out=out_bad_shape)
1485+
out_bad_queue = dpnp.empty(chcs.shape, dtype=chcs.dtype, sycl_queue=q2)
1486+
with pytest.raises(ExecutionPlacementError):
1487+
dpnp.choose(inds, [chcs], out=out_bad_queue)
1488+
out_bad_dt = dpnp.empty(chcs.shape, dtype="i8", sycl_queue=q1)
1489+
with pytest.raises(TypeError):
1490+
dpnp.choose(inds, [chcs], out=out_bad_dt)
1491+
out_read_only = dpnp.empty(chcs.shape, dtype=chcs.dtype, sycl_queue=q1)
1492+
out_read_only.flags.writable = False
1493+
with pytest.raises(ValueError):
1494+
dpnp.choose(inds, [chcs], out=out_read_only)
1495+
1496+
def test_choose_empty(self):
1497+
sh = (10, 0, 5)
1498+
inds = dpnp.ones(sh, dtype="i4")
1499+
chcs = dpnp.ones(sh)
1500+
r = dpnp.choose(inds, chcs)
1501+
assert r.shape == sh
1502+
r = dpnp.choose(inds, (chcs,) * 2)
1503+
assert r.shape == sh
1504+
inds = dpnp.unstack(inds)[0]
1505+
r = dpnp.choose(inds, chcs)
1506+
assert r.shape == sh[1:]
1507+
r = dpnp.choose(inds, [chcs])
1508+
assert r.shape == sh

0 commit comments

Comments
 (0)