Skip to content

Commit 904d629

Browse files
committed
Adds more tests for choose
1 parent c486e52 commit 904d629

File tree

3 files changed

+136
-0
lines changed

3 files changed

+136
-0
lines changed

dpnp/tests/test_indexing.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy
66
import pytest
77
from dpctl.tensor._numpy_helper import AxisError
8+
from dpctl.tensor._type_utils import _to_device_supported_dtype
89
from dpctl.utils import ExecutionPlacementError
910
from numpy.testing import (
1011
assert_,
@@ -1507,3 +1508,103 @@ def test_choose_0d_inputs(self):
15071508
chc = dpnp.ones(sh, dtype="i4")
15081509
r = dpnp.choose(inds, [chc])
15091510
assert r == chc
1511+
1512+
def test_choose_in_overlaps_out(self):
1513+
# overlap with inds
1514+
inds = dpnp.zeros(6, dtype="i4")
1515+
inds_np = dpnp.asnumpy(inds)
1516+
chc_np = numpy.arange(6, dtype="i4")
1517+
chc = dpnp.arange(6, dtype="i4")
1518+
out = inds
1519+
expected = numpy.choose(inds_np, chc_np)
1520+
result = dpnp.choose(inds, chc, out=out)
1521+
assert_array_equal(expected, result)
1522+
assert result is out
1523+
assert (inds == out).all()
1524+
# overlap with chc
1525+
inds = dpnp.zeros(6, dtype="i4")
1526+
out = chc
1527+
expected = numpy.choose(inds_np, chc_np)
1528+
result = dpnp.choose(inds, chc, out=out)
1529+
assert_array_equal(expected, result)
1530+
assert result is out
1531+
assert (inds == out).all()
1532+
1533+
def test_choose_strided(self):
1534+
# inds strided
1535+
inds = dpnp.tile(dpnp.array([0, 1], dtype="i4"), 5)
1536+
inds_np = dpnp.asnumpy(inds)
1537+
c1 = dpnp.arange(5, dtype="i4")
1538+
c2 = dpnp.full(5, -1, dtype="i4")
1539+
chcs = [c1, c2]
1540+
chcs_np = [dpnp.asnumpy(chc) for chc in chcs]
1541+
result = dpnp.choose(inds[::-2], chcs)
1542+
expected = numpy.choose(inds_np[::-2], chcs_np)
1543+
assert_array_equal(result, expected)
1544+
# choices strided
1545+
c3 = dpnp.arange(20, dtype="i4")
1546+
c4 = dpnp.full(20, -1, dtype="i4")
1547+
chcs = [c3[::-2], c4[::-2]]
1548+
chcs_np = [dpnp.asnumpy(c3)[::-2], dpnp.asnumpy(c4)[::-2]]
1549+
result = dpnp.choose(inds, chcs)
1550+
expected = numpy.choose(inds_np, chcs_np)
1551+
assert_array_equal(result, expected)
1552+
# all strided
1553+
result = dpnp.choose(inds[::-1], chcs)
1554+
expected = numpy.choose(inds_np[::-1], chcs_np)
1555+
assert_array_equal(result, expected)
1556+
1557+
@pytest.mark.parametrize(
1558+
"indices", [[0, 2], [-5, 4]], ids=["[0, 2]", "[-5, 4]"]
1559+
)
1560+
@pytest.mark.parametrize("mode", ["clip", "wrap"])
1561+
def test_choose_modes(self, indices, mode):
1562+
chc = dpnp.array([-2, -1, 0, 1, 2], dtype="i4")
1563+
chc_np = dpnp.asnumpy(chc)
1564+
inds = dpnp.array(indices, dtype="i4")
1565+
inds_np = dpnp.asnumpy(inds)
1566+
expected = numpy.choose(inds_np, chc_np, mode=mode)
1567+
result = dpnp.choose(inds, chc, mode=mode)
1568+
assert_array_equal(expected, result)
1569+
1570+
def test_choose_arg_validation(self):
1571+
with pytest.raises(TypeError):
1572+
dpnp.choose(dpnp.zeros(()), 1)
1573+
with pytest.raises(ValueError):
1574+
dpnp.choose(dpnp.zeros(()), dpnp.ones(()), mode="err")
1575+
1576+
# based on examples from NumPy
1577+
def test_choose_broadcasting(self):
1578+
inds = dpnp.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype="i4")
1579+
inds_np = dpnp.asnumpy(inds)
1580+
chcs = dpnp.array([-10, 10])
1581+
chcs_np = dpnp.asnumpy(chcs)
1582+
result = dpnp.choose(inds, chcs)
1583+
expected = numpy.choose(inds_np, chcs_np)
1584+
assert_array_equal(result, expected)
1585+
1586+
inds = dpnp.array([0, 1]).reshape((2, 1, 1))
1587+
inds_np = dpnp.asnumpy(inds)
1588+
chc1 = dpnp.array([1, 2, 3]).reshape((1, 3, 1))
1589+
chc2 = dpnp.array([-1, -2, -3, -4, -5]).reshape(1, 1, 5)
1590+
chcs = [chc1, chc2]
1591+
chcs_np = [dpnp.asnumpy(chc) for chc in chcs]
1592+
result = dpnp.choose(inds, chcs)
1593+
expected = numpy.choose(inds_np, chcs_np)
1594+
assert_array_equal(result, expected)
1595+
1596+
@pytest.mark.parametrize("chc1_dt", get_all_dtypes())
1597+
@pytest.mark.parametrize("chc2_dt", get_all_dtypes())
1598+
def test_choose_promote_choices(self, chc1_dt, chc2_dt):
1599+
inds = dpnp.array([0, 1], dtype="i4")
1600+
inds_np = dpnp.asnumpy(inds)
1601+
chc1 = dpnp.zeros(1, dtype=chc1_dt)
1602+
chc2 = dpnp.ones(1, dtype=chc2_dt)
1603+
chcs = [chc1, chc2]
1604+
chcs_np = [dpnp.asnumpy(chc) for chc in chcs]
1605+
result = dpnp.choose(inds, chcs)
1606+
expected = numpy.choose(inds_np, chcs_np)
1607+
assert (
1608+
_to_device_supported_dtype(expected.dtype, inds.sycl_device)
1609+
== result.dtype
1610+
)

dpnp/tests/test_sycl_queue.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2880,3 +2880,24 @@ def test_ix(device_0, device_1):
28802880
ixgrid = dpnp.ix_(x0, x1)
28812881
assert_sycl_queue_equal(ixgrid[0].sycl_queue, x0.sycl_queue)
28822882
assert_sycl_queue_equal(ixgrid[1].sycl_queue, x1.sycl_queue)
2883+
2884+
2885+
@pytest.mark.parametrize(
2886+
"device",
2887+
valid_devices,
2888+
ids=[device.filter_string for device in valid_devices],
2889+
)
2890+
def test_choose(device):
2891+
chc = dpnp.arange(5, dtype="i4", device=device)
2892+
chc_np = dpnp.asnumpy(chc)
2893+
2894+
inds = dpnp.array([0, 1, 3], dtype="i4", device=device)
2895+
inds_np = dpnp.asnumpy(inds)
2896+
2897+
result = dpnp.choose(inds, chc)
2898+
expected = numpy.choose(inds_np, chc_np)
2899+
assert_allclose(expected, result)
2900+
2901+
expected_queue = chc.sycl_queue
2902+
result_queue = result.sycl_queue
2903+
assert_sycl_queue_equal(result_queue, expected_queue)

dpnp/tests/test_usm_type.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,3 +1764,17 @@ def test_ix(usm_type_0, usm_type_1):
17641764
ixgrid = dp.ix_(x0, x1)
17651765
assert ixgrid[0].usm_type == x0.usm_type
17661766
assert ixgrid[1].usm_type == x1.usm_type
1767+
1768+
1769+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
1770+
@pytest.mark.parametrize(
1771+
"usm_type_ind", list_of_usm_types, ids=list_of_usm_types
1772+
)
1773+
def test_choose(usm_type_x, usm_type_ind):
1774+
chc = dp.arange(5, usm_type=usm_type_x)
1775+
ind = dp.array([0, 2, 4], usm_type=usm_type_ind)
1776+
z = dp.choose(ind, chc)
1777+
1778+
assert chc.usm_type == usm_type_x
1779+
assert ind.usm_type == usm_type_ind
1780+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])

0 commit comments

Comments
 (0)