Skip to content

Commit 6f51c97

Browse files
committed
Adds more tests for choose
1 parent 2c13bcb commit 6f51c97

File tree

3 files changed

+136
-0
lines changed

3 files changed

+136
-0
lines changed

dpnp/tests/test_indexing.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy
66
import pytest
77
from dpctl.tensor._numpy_helper import AxisError
8+
from dpctl.tensor._type_utils import _to_device_supported_dtype
89
from dpctl.utils import ExecutionPlacementError
910
from numpy.testing import (
1011
assert_,
@@ -1513,3 +1514,103 @@ def test_choose_0d_inputs(self):
15131514
chc = dpnp.ones(sh, dtype="i4")
15141515
r = dpnp.choose(inds, [chc])
15151516
assert r == chc
1517+
1518+
def test_choose_in_overlaps_out(self):
1519+
# overlap with inds
1520+
inds = dpnp.zeros(6, dtype="i4")
1521+
inds_np = dpnp.asnumpy(inds)
1522+
chc_np = numpy.arange(6, dtype="i4")
1523+
chc = dpnp.arange(6, dtype="i4")
1524+
out = inds
1525+
expected = numpy.choose(inds_np, chc_np)
1526+
result = dpnp.choose(inds, chc, out=out)
1527+
assert_array_equal(expected, result)
1528+
assert result is out
1529+
assert (inds == out).all()
1530+
# overlap with chc
1531+
inds = dpnp.zeros(6, dtype="i4")
1532+
out = chc
1533+
expected = numpy.choose(inds_np, chc_np)
1534+
result = dpnp.choose(inds, chc, out=out)
1535+
assert_array_equal(expected, result)
1536+
assert result is out
1537+
assert (inds == out).all()
1538+
1539+
def test_choose_strided(self):
1540+
# inds strided
1541+
inds = dpnp.tile(dpnp.array([0, 1], dtype="i4"), 5)
1542+
inds_np = dpnp.asnumpy(inds)
1543+
c1 = dpnp.arange(5, dtype="i4")
1544+
c2 = dpnp.full(5, -1, dtype="i4")
1545+
chcs = [c1, c2]
1546+
chcs_np = [dpnp.asnumpy(chc) for chc in chcs]
1547+
result = dpnp.choose(inds[::-2], chcs)
1548+
expected = numpy.choose(inds_np[::-2], chcs_np)
1549+
assert_array_equal(result, expected)
1550+
# choices strided
1551+
c3 = dpnp.arange(20, dtype="i4")
1552+
c4 = dpnp.full(20, -1, dtype="i4")
1553+
chcs = [c3[::-2], c4[::-2]]
1554+
chcs_np = [dpnp.asnumpy(c3)[::-2], dpnp.asnumpy(c4)[::-2]]
1555+
result = dpnp.choose(inds, chcs)
1556+
expected = numpy.choose(inds_np, chcs_np)
1557+
assert_array_equal(result, expected)
1558+
# all strided
1559+
result = dpnp.choose(inds[::-1], chcs)
1560+
expected = numpy.choose(inds_np[::-1], chcs_np)
1561+
assert_array_equal(result, expected)
1562+
1563+
@pytest.mark.parametrize(
1564+
"indices", [[0, 2], [-5, 4]], ids=["[0, 2]", "[-5, 4]"]
1565+
)
1566+
@pytest.mark.parametrize("mode", ["clip", "wrap"])
1567+
def test_choose_modes(self, indices, mode):
1568+
chc = dpnp.array([-2, -1, 0, 1, 2], dtype="i4")
1569+
chc_np = dpnp.asnumpy(chc)
1570+
inds = dpnp.array(indices, dtype="i4")
1571+
inds_np = dpnp.asnumpy(inds)
1572+
expected = numpy.choose(inds_np, chc_np, mode=mode)
1573+
result = dpnp.choose(inds, chc, mode=mode)
1574+
assert_array_equal(expected, result)
1575+
1576+
def test_choose_arg_validation(self):
1577+
with pytest.raises(TypeError):
1578+
dpnp.choose(dpnp.zeros(()), 1)
1579+
with pytest.raises(ValueError):
1580+
dpnp.choose(dpnp.zeros(()), dpnp.ones(()), mode="err")
1581+
1582+
# based on examples from NumPy
1583+
def test_choose_broadcasting(self):
1584+
inds = dpnp.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype="i4")
1585+
inds_np = dpnp.asnumpy(inds)
1586+
chcs = dpnp.array([-10, 10])
1587+
chcs_np = dpnp.asnumpy(chcs)
1588+
result = dpnp.choose(inds, chcs)
1589+
expected = numpy.choose(inds_np, chcs_np)
1590+
assert_array_equal(result, expected)
1591+
1592+
inds = dpnp.array([0, 1]).reshape((2, 1, 1))
1593+
inds_np = dpnp.asnumpy(inds)
1594+
chc1 = dpnp.array([1, 2, 3]).reshape((1, 3, 1))
1595+
chc2 = dpnp.array([-1, -2, -3, -4, -5]).reshape(1, 1, 5)
1596+
chcs = [chc1, chc2]
1597+
chcs_np = [dpnp.asnumpy(chc) for chc in chcs]
1598+
result = dpnp.choose(inds, chcs)
1599+
expected = numpy.choose(inds_np, chcs_np)
1600+
assert_array_equal(result, expected)
1601+
1602+
@pytest.mark.parametrize("chc1_dt", get_all_dtypes())
1603+
@pytest.mark.parametrize("chc2_dt", get_all_dtypes())
1604+
def test_choose_promote_choices(self, chc1_dt, chc2_dt):
1605+
inds = dpnp.array([0, 1], dtype="i4")
1606+
inds_np = dpnp.asnumpy(inds)
1607+
chc1 = dpnp.zeros(1, dtype=chc1_dt)
1608+
chc2 = dpnp.ones(1, dtype=chc2_dt)
1609+
chcs = [chc1, chc2]
1610+
chcs_np = [dpnp.asnumpy(chc) for chc in chcs]
1611+
result = dpnp.choose(inds, chcs)
1612+
expected = numpy.choose(inds_np, chcs_np)
1613+
assert (
1614+
_to_device_supported_dtype(expected.dtype, inds.sycl_device)
1615+
== result.dtype
1616+
)

dpnp/tests/test_sycl_queue.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2923,3 +2923,24 @@ def test_ix(device_0, device_1):
29232923
ixgrid = dpnp.ix_(x0, x1)
29242924
assert_sycl_queue_equal(ixgrid[0].sycl_queue, x0.sycl_queue)
29252925
assert_sycl_queue_equal(ixgrid[1].sycl_queue, x1.sycl_queue)
2926+
2927+
2928+
@pytest.mark.parametrize(
2929+
"device",
2930+
valid_devices,
2931+
ids=[device.filter_string for device in valid_devices],
2932+
)
2933+
def test_choose(device):
2934+
chc = dpnp.arange(5, dtype="i4", device=device)
2935+
chc_np = dpnp.asnumpy(chc)
2936+
2937+
inds = dpnp.array([0, 1, 3], dtype="i4", device=device)
2938+
inds_np = dpnp.asnumpy(inds)
2939+
2940+
result = dpnp.choose(inds, chc)
2941+
expected = numpy.choose(inds_np, chc_np)
2942+
assert_allclose(expected, result)
2943+
2944+
expected_queue = chc.sycl_queue
2945+
result_queue = result.sycl_queue
2946+
assert_sycl_queue_equal(result_queue, expected_queue)

dpnp/tests/test_usm_type.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,3 +1768,17 @@ def test_ix(usm_type_0, usm_type_1):
17681768
ixgrid = dp.ix_(x0, x1)
17691769
assert ixgrid[0].usm_type == x0.usm_type
17701770
assert ixgrid[1].usm_type == x1.usm_type
1771+
1772+
1773+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
1774+
@pytest.mark.parametrize(
1775+
"usm_type_ind", list_of_usm_types, ids=list_of_usm_types
1776+
)
1777+
def test_choose(usm_type_x, usm_type_ind):
1778+
chc = dp.arange(5, usm_type=usm_type_x)
1779+
ind = dp.array([0, 2, 4], usm_type=usm_type_ind)
1780+
z = dp.choose(ind, chc)
1781+
1782+
assert chc.usm_type == usm_type_x
1783+
assert ind.usm_type == usm_type_ind
1784+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])

0 commit comments

Comments
 (0)