Skip to content

Commit 1d32bbe

Browse files
Update cupy tests for common_type()
1 parent 0ffa47e commit 1d32bbe

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

dpnp/tests/third_party/cupy/test_type_routines.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
import dpnp as cupy
7-
from dpnp.tests.helper import has_support_aspect64
7+
from dpnp.tests.helper import has_support_aspect16, has_support_aspect64
88
from dpnp.tests.third_party.cupy import testing
99

1010

@@ -47,13 +47,17 @@ def test_can_cast(self, xp, from_dtype, to_dtype):
4747
return ret
4848

4949

50-
@pytest.mark.skip("dpnp.common_type() is not implemented yet")
5150
class TestCommonType(unittest.TestCase):
5251

5352
@testing.numpy_cupy_equal()
5453
def test_common_type_empty(self, xp):
5554
ret = xp.common_type()
5655
assert type(ret) is type
56+
# NumPy always returns float16 for empty input,
57+
# but dpnp returns float32 if the device does not support
58+
# 16-bit precision floating point operations
59+
if xp is numpy and not has_support_aspect16():
60+
return xp.float32
5761
return ret
5862

5963
@testing.for_all_dtypes(no_bool=True)
@@ -62,6 +66,11 @@ def test_common_type_single_argument(self, xp, dtype):
6266
array = _generate_type_routines_input(xp, dtype, "array")
6367
ret = xp.common_type(array)
6468
assert type(ret) is type
69+
# NumPy promotes integer types to float64,
70+
# but dpnp may return float32 if the device does not support
71+
# 64-bit precision floating point operations.
72+
if xp is numpy and not has_support_aspect64():
73+
return xp.float32
6574
return ret
6675

6776
@testing.for_all_dtypes_combination(
@@ -73,6 +82,8 @@ def test_common_type_two_arguments(self, xp, dtype1, dtype2):
7382
array2 = _generate_type_routines_input(xp, dtype2, "array")
7483
ret = xp.common_type(array1, array2)
7584
assert type(ret) is type
85+
if xp is numpy and not has_support_aspect64():
86+
return xp.float32
7687
return ret
7788

7889
@testing.for_all_dtypes()

0 commit comments

Comments
 (0)