Skip to content

Commit 7fc36a6

Browse files
Add test that result_types(dtypes) works the same for Python/NumPy scalars
1 parent 6fd61c3 commit 7fc36a6

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616

1717

18+
import itertools
19+
1820
import numpy as np
1921
import pytest
2022
from numpy.testing import assert_, assert_array_equal, assert_raises_regex
@@ -1531,3 +1533,22 @@ def test_repeat_0_size():
15311533
res = dpt.repeat(x, repetitions, axis=1)
15321534
axis_sz = 2 * x.shape[1]
15331535
assert res.shape == (0, axis_sz, 0)
1536+
1537+
1538+
def test_result_type_bug_1874():
1539+
dts_bool = [True, np.bool_(True)]
1540+
dts_ints = [int(1), np.int64(1)]
1541+
dts_floats = [float(1), np.float64(1)]
1542+
dts_complexes = [complex(1), np.complex128(1)]
1543+
1544+
# iterate over two categories
1545+
for dts1, dts2 in itertools.product(
1546+
[dts_bool, dts_ints, dts_floats, dts_complexes], repeat=2
1547+
):
1548+
res_dts = []
1549+
# iterate over Python scalar/NumPy scalar choices within categories
1550+
for dt1, dt2 in itertools.product(dts1, dts2):
1551+
res_dt = dpt.result_type(dt1, dt2)
1552+
res_dts.append(res_dt)
1553+
# check that all results are the same
1554+
assert res_dts and all(res_dts[0] == el for el in res_dts[1:])

0 commit comments

Comments
 (0)