|
5 | 5 | import numpy |
6 | 6 | import pytest |
7 | 7 | from dpctl.tensor._numpy_helper import AxisError |
| 8 | +from dpctl.tensor._type_utils import _to_device_supported_dtype |
8 | 9 | from dpctl.utils import ExecutionPlacementError |
9 | 10 | from numpy.testing import ( |
10 | 11 | assert_, |
@@ -1507,3 +1508,103 @@ def test_choose_0d_inputs(self): |
1507 | 1508 | chc = dpnp.ones(sh, dtype="i4") |
1508 | 1509 | r = dpnp.choose(inds, [chc]) |
1509 | 1510 | assert r == chc |
| 1511 | + |
| 1512 | + def test_choose_in_overlaps_out(self): |
| 1513 | + # overlap with inds |
| 1514 | + inds = dpnp.zeros(6, dtype="i4") |
| 1515 | + inds_np = dpnp.asnumpy(inds) |
| 1516 | + chc_np = numpy.arange(6, dtype="i4") |
| 1517 | + chc = dpnp.arange(6, dtype="i4") |
| 1518 | + out = inds |
| 1519 | + expected = numpy.choose(inds_np, chc_np) |
| 1520 | + result = dpnp.choose(inds, chc, out=out) |
| 1521 | + assert_array_equal(expected, result) |
| 1522 | + assert result is out |
| 1523 | + assert (inds == out).all() |
| 1524 | + # overlap with chc |
| 1525 | + inds = dpnp.zeros(6, dtype="i4") |
| 1526 | + out = chc |
| 1527 | + expected = numpy.choose(inds_np, chc_np) |
| 1528 | + result = dpnp.choose(inds, chc, out=out) |
| 1529 | + assert_array_equal(expected, result) |
| 1530 | + assert result is out |
| 1531 | + assert (inds == out).all() |
| 1532 | + |
| 1533 | + def test_choose_strided(self): |
| 1534 | + # inds strided |
| 1535 | + inds = dpnp.tile(dpnp.array([0, 1], dtype="i4"), 5) |
| 1536 | + inds_np = dpnp.asnumpy(inds) |
| 1537 | + c1 = dpnp.arange(5, dtype="i4") |
| 1538 | + c2 = dpnp.full(5, -1, dtype="i4") |
| 1539 | + chcs = [c1, c2] |
| 1540 | + chcs_np = [dpnp.asnumpy(chc) for chc in chcs] |
| 1541 | + result = dpnp.choose(inds[::-2], chcs) |
| 1542 | + expected = numpy.choose(inds_np[::-2], chcs_np) |
| 1543 | + assert_array_equal(result, expected) |
| 1544 | + # choices strided |
| 1545 | + c3 = dpnp.arange(20, dtype="i4") |
| 1546 | + c4 = dpnp.full(20, -1, dtype="i4") |
| 1547 | + chcs = [c3[::-2], c4[::-2]] |
| 1548 | + chcs_np = [dpnp.asnumpy(c3)[::-2], dpnp.asnumpy(c4)[::-2]] |
| 1549 | + result = dpnp.choose(inds, chcs) |
| 1550 | + expected = numpy.choose(inds_np, chcs_np) |
| 1551 | + assert_array_equal(result, expected) |
| 1552 | + # all strided |
| 1553 | + result = dpnp.choose(inds[::-1], chcs) |
| 1554 | + expected = numpy.choose(inds_np[::-1], chcs_np) |
| 1555 | + assert_array_equal(result, expected) |
| 1556 | + |
| 1557 | + @pytest.mark.parametrize( |
| 1558 | + "indices", [[0, 2], [-5, 4]], ids=["[0, 2]", "[-5, 4]"] |
| 1559 | + ) |
| 1560 | + @pytest.mark.parametrize("mode", ["clip", "wrap"]) |
| 1561 | + def test_choose_modes(self, indices, mode): |
| 1562 | + chc = dpnp.array([-2, -1, 0, 1, 2], dtype="i4") |
| 1563 | + chc_np = dpnp.asnumpy(chc) |
| 1564 | + inds = dpnp.array(indices, dtype="i4") |
| 1565 | + inds_np = dpnp.asnumpy(inds) |
| 1566 | + expected = numpy.choose(inds_np, chc_np, mode=mode) |
| 1567 | + result = dpnp.choose(inds, chc, mode=mode) |
| 1568 | + assert_array_equal(expected, result) |
| 1569 | + |
| 1570 | + def test_choose_arg_validation(self): |
| 1571 | + with pytest.raises(TypeError): |
| 1572 | + dpnp.choose(dpnp.zeros(()), 1) |
| 1573 | + with pytest.raises(ValueError): |
| 1574 | + dpnp.choose(dpnp.zeros(()), dpnp.ones(()), mode="err") |
| 1575 | + |
| 1576 | + # based on examples from NumPy |
| 1577 | + def test_choose_broadcasting(self): |
| 1578 | + inds = dpnp.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype="i4") |
| 1579 | + inds_np = dpnp.asnumpy(inds) |
| 1580 | + chcs = dpnp.array([-10, 10]) |
| 1581 | + chcs_np = dpnp.asnumpy(chcs) |
| 1582 | + result = dpnp.choose(inds, chcs) |
| 1583 | + expected = numpy.choose(inds_np, chcs_np) |
| 1584 | + assert_array_equal(result, expected) |
| 1585 | + |
| 1586 | + inds = dpnp.array([0, 1]).reshape((2, 1, 1)) |
| 1587 | + inds_np = dpnp.asnumpy(inds) |
| 1588 | + chc1 = dpnp.array([1, 2, 3]).reshape((1, 3, 1)) |
| 1589 | + chc2 = dpnp.array([-1, -2, -3, -4, -5]).reshape(1, 1, 5) |
| 1590 | + chcs = [chc1, chc2] |
| 1591 | + chcs_np = [dpnp.asnumpy(chc) for chc in chcs] |
| 1592 | + result = dpnp.choose(inds, chcs) |
| 1593 | + expected = numpy.choose(inds_np, chcs_np) |
| 1594 | + assert_array_equal(result, expected) |
| 1595 | + |
| 1596 | + @pytest.mark.parametrize("chc1_dt", get_all_dtypes()) |
| 1597 | + @pytest.mark.parametrize("chc2_dt", get_all_dtypes()) |
| 1598 | + def test_choose_promote_choices(self, chc1_dt, chc2_dt): |
| 1599 | + inds = dpnp.array([0, 1], dtype="i4") |
| 1600 | + inds_np = dpnp.asnumpy(inds) |
| 1601 | + chc1 = dpnp.zeros(1, dtype=chc1_dt) |
| 1602 | + chc2 = dpnp.ones(1, dtype=chc2_dt) |
| 1603 | + chcs = [chc1, chc2] |
| 1604 | + chcs_np = [dpnp.asnumpy(chc) for chc in chcs] |
| 1605 | + result = dpnp.choose(inds, chcs) |
| 1606 | + expected = numpy.choose(inds_np, chcs_np) |
| 1607 | + assert ( |
| 1608 | + _to_device_supported_dtype(expected.dtype, inds.sycl_device) |
| 1609 | + == result.dtype |
| 1610 | + ) |
0 commit comments