|
36 | 36 | from .third_party.cupy import testing |
37 | 37 |
|
38 | 38 |
|
39 | | -def _get_output_data_type(dtype): |
40 | | - """Return a data type specified by input `dtype` and device capabilities.""" |
41 | | - dtype_float16 = any( |
42 | | - dpnp.issubdtype(dtype, t) for t in (dpnp.bool, dpnp.int8, dpnp.uint8) |
43 | | - ) |
44 | | - dtype_float32 = any( |
45 | | - dpnp.issubdtype(dtype, t) for t in (dpnp.int16, dpnp.uint16) |
46 | | - ) |
47 | | - if dtype_float16: |
48 | | - out_dtype = dpnp.float16 if has_support_aspect16() else dpnp.float32 |
49 | | - elif dtype_float32: |
50 | | - out_dtype = dpnp.float32 |
51 | | - elif dpnp.issubdtype(dtype, dpnp.complexfloating): |
52 | | - out_dtype = dpnp.complex64 |
53 | | - if has_support_aspect64() and dtype != dpnp.complex64: |
54 | | - out_dtype = dpnp.complex128 |
55 | | - else: |
56 | | - out_dtype = dpnp.float32 |
57 | | - if has_support_aspect64() and dtype != dpnp.float32: |
58 | | - out_dtype = dpnp.float64 |
59 | | - |
60 | | - return out_dtype |
61 | | - |
62 | | - |
63 | 39 | @pytest.mark.parametrize("deg", [True, False]) |
64 | 40 | class TestAngle: |
65 | 41 | def test_angle_bool(self, deg): |
@@ -775,6 +751,16 @@ def test_errors(self): |
775 | 751 |
|
776 | 752 |
|
777 | 753 | class TestFix: |
| 754 | + def get_output_data_type(self, dtype, device): |
| 755 | + if dpnp.can_cast(dtype, dpnp.float16): |
| 756 | + return dpnp.float16 |
| 757 | + if dpnp.can_cast(dtype, dpnp.float32): |
| 758 | + return dpnp.float32 |
| 759 | + if dpnp.can_cast(dtype, dpnp.float64): |
| 760 | + return dpnp.float64 |
| 761 | + |
| 762 | + return dpnp.default_float_type(device) |
| 763 | + |
778 | 764 | @pytest.mark.parametrize( |
779 | 765 | "dt", get_all_dtypes(no_none=True, no_complex=True) |
780 | 766 | ) |
@@ -802,7 +788,9 @@ def test_out(self, a_dt): |
802 | 788 | ) |
803 | 789 | ia = dpnp.array(a) |
804 | 790 |
|
805 | | - out_dt = _get_output_data_type(a.dtype) |
| 791 | + # numpy output has the same dtype as input |
| 792 | + # dpnp output always has a floating point dtype |
| 793 | + out_dt = self.get_output_data_type(a.dtype, a.device) |
806 | 794 | out = numpy.zeros_like(a, dtype=out_dt) |
807 | 795 | iout = dpnp.array(out) |
808 | 796 |
|
|
0 commit comments