Skip to content

Commit 70a85c3

Browse files
committed
Test choose with non-overlapping out and choose with invalid choice object
1 parent 12f4ccf commit 70a85c3

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

dpnp/tests/test_indexing.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,19 @@ def test_choose_0d_inputs(self):
15091509
r = dpnp.choose(inds, [chc])
15101510
assert r == chc
15111511

1512+
def test_choose_out_keyword(self):
1513+
inds = dpnp.tile(dpnp.array([0, 1, 2], dtype="i4"), (5, 3))
1514+
inds_np = dpnp.asnumpy(inds)
1515+
chc1 = dpnp.zeros(9, dtype="f4")
1516+
chc2 = dpnp.ones(9, dtype="f4")
1517+
chc3 = dpnp.full(9, 2, dtype="f4")
1518+
chcs = [chc1, chc2, chc3]
1519+
chcs_np = [dpnp.asnumpy(chc) for chc in chcs]
1520+
out = dpnp.empty_like(inds, dtype="f4")
1521+
dpnp.choose(inds, chcs, out=out)
1522+
expected = numpy.choose(inds_np, chcs_np)
1523+
assert_array_equal(out, expected)
1524+
15121525
def test_choose_in_overlaps_out(self):
15131526
# overlap with inds
15141527
inds = dpnp.zeros(6, dtype="i4")
@@ -1568,8 +1581,10 @@ def test_choose_modes(self, indices, mode):
15681581
assert_array_equal(expected, result)
15691582

15701583
def test_choose_arg_validation(self):
1584+
# invalid choices
15711585
with pytest.raises(TypeError):
1572-
dpnp.choose(dpnp.zeros(()), 1)
1586+
dpnp.choose(dpnp.zeros((), dtype="i4"), 1)
1587+
# invalid mode keyword
15731588
with pytest.raises(ValueError):
15741589
dpnp.choose(dpnp.zeros(()), dpnp.ones(()), mode="err")
15751590

0 commit comments

Comments
 (0)