99)
1010
1111import dpnp
12+ from dpnp .dpnp_utils import map_dtype_to_device
1213
1314from .helper import (
1415 assert_dtype_allclose ,
@@ -108,30 +109,6 @@ def test_umaths(test_cases):
108109 assert_allclose (result , expected , rtol = 1e-6 )
109110
110111
111- def _get_output_data_type (dtype ):
112- """Return a data type specified by input `dtype` and device capabilities."""
113- dtype_float16 = any (
114- dpnp .issubdtype (dtype , t ) for t in (dpnp .bool , dpnp .int8 , dpnp .uint8 )
115- )
116- dtype_float32 = any (
117- dpnp .issubdtype (dtype , t ) for t in (dpnp .int16 , dpnp .uint16 )
118- )
119- if dtype_float16 :
120- dt_out = dpnp .float16 if has_support_aspect16 () else dpnp .float32
121- elif dtype_float32 :
122- dt_out = dpnp .float32
123- elif dpnp .issubdtype (dtype , dpnp .complexfloating ):
124- dt_out = dpnp .complex64
125- if has_support_aspect64 () and dtype != dpnp .complex64 :
126- dt_out = dpnp .complex128
127- else :
128- dt_out = dpnp .float32
129- if has_support_aspect64 () and dtype != dpnp .float32 :
130- dt_out = dpnp .float64
131-
132- return dt_out
133-
134-
135112class TestArctan2 :
136113 @pytest .mark .parametrize (
137114 "dtype" , get_all_dtypes (no_none = True , no_complex = True )
@@ -142,10 +119,10 @@ def test_arctan2(self, dtype):
142119 expected = numpy .arctan2 (a , b )
143120
144121 ia , ib = dpnp .array (a ), dpnp .array (b )
145- dt_out = _get_output_data_type ( dtype )
122+ dt_out = map_dtype_to_device ( expected . dtype , ia . sycl_device )
146123 iout = dpnp .empty (expected .shape , dtype = dt_out )
147- result = dpnp .arctan2 (ia , ib , out = iout )
148124
125+ result = dpnp .arctan2 (ia , ib , out = iout )
149126 assert result is iout
150127 assert_dtype_allclose (result , expected )
151128
@@ -188,7 +165,7 @@ def test_copysign(self, dtype):
188165 expected = numpy .copysign (a , b )
189166
190167 ia , ib = dpnp .array (a ), dpnp .array (b )
191- dt_out = _get_output_data_type ( dtype )
168+ dt_out = map_dtype_to_device ( expected . dtype , ia . sycl_device )
192169 iout = dpnp .empty (expected .shape , dtype = dt_out )
193170 result = dpnp .copysign (ia , ib , out = iout )
194171
@@ -307,7 +284,7 @@ def test_logaddexp(self, dtype):
307284 expected = numpy .logaddexp (a , b )
308285
309286 ia , ib = dpnp .array (a ), dpnp .array (b )
310- dt_out = _get_output_data_type ( dtype )
287+ dt_out = map_dtype_to_device ( expected . dtype , ia . sycl_device )
311288 iout = dpnp .empty (expected .shape , dtype = dt_out )
312289 result = dpnp .logaddexp (ia , ib , out = iout )
313290
@@ -450,7 +427,7 @@ def test_reciprocal(self, dtype):
450427 expected = numpy .reciprocal (a )
451428
452429 ia = dpnp .array (a )
453- dt_out = _get_output_data_type ( dtype )
430+ dt_out = map_dtype_to_device ( expected . dtype , ia . sycl_device )
454431 iout = dpnp .empty (expected .shape , dtype = dt_out )
455432 result = dpnp .reciprocal (ia , out = iout )
456433
@@ -500,7 +477,7 @@ def test_basic(self, func_params, dtype):
500477 expected = getattr (numpy , func )(a )
501478
502479 ia = dpnp .array (a )
503- dt_out = _get_output_data_type ( dtype )
480+ dt_out = map_dtype_to_device ( expected . dtype , ia . sycl_device )
504481 iout = dpnp .empty (expected .shape , dtype = dt_out )
505482 result = getattr (dpnp , func )(ia , out = iout )
506483 assert result is iout
@@ -591,7 +568,7 @@ def func_params(self, request):
591568 @pytest .mark .filterwarnings ("ignore:overflow encountered:RuntimeWarning" )
592569 @pytest .mark .usefixtures ("suppress_divide_invalid_numpy_warnings" )
593570 @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
594- def test_out (self , func_params , dtype ):
571+ def test_basic (self , func_params , dtype ):
595572 func = func_params ["func" ]
596573 values = func_params ["values" ]
597574 a = generate_random_numpy_array (
@@ -600,10 +577,7 @@ def test_out(self, func_params, dtype):
600577 expected = getattr (numpy , func )(a )
601578
602579 ia = dpnp .array (a )
603- if func == "square" :
604- dt_out = numpy .int8 if dtype == dpnp .bool else dtype
605- else :
606- dt_out = _get_output_data_type (dtype )
580+ dt_out = map_dtype_to_device (expected .dtype , ia .sycl_device )
607581 iout = dpnp .empty (expected .shape , dtype = dt_out )
608582 result = getattr (dpnp , func )(ia , out = iout )
609583
0 commit comments