3636from .third_party .cupy import testing
3737
3838
39- def _get_output_data_type (dtype ):
40- """Return a data type specified by input `dtype` and device capabilities."""
41- dtype_float16 = any (
42- dpnp .issubdtype (dtype , t ) for t in (dpnp .bool , dpnp .int8 , dpnp .uint8 )
43- )
44- dtype_float32 = any (
45- dpnp .issubdtype (dtype , t ) for t in (dpnp .int16 , dpnp .uint16 )
46- )
47- if dtype_float16 :
48- out_dtype = dpnp .float16 if has_support_aspect16 () else dpnp .float32
49- elif dtype_float32 :
50- out_dtype = dpnp .float32
51- elif dpnp .issubdtype (dtype , dpnp .complexfloating ):
52- out_dtype = dpnp .complex64
53- if has_support_aspect64 () and dtype != dpnp .complex64 :
54- out_dtype = dpnp .complex128
55- else :
56- out_dtype = dpnp .float32
57- if has_support_aspect64 () and dtype != dpnp .float32 :
58- out_dtype = dpnp .float64
59-
60- return out_dtype
61-
62-
6339@pytest .mark .parametrize ("deg" , [True , False ])
6440class TestAngle :
6541 def test_angle_bool (self , deg ):
@@ -775,6 +751,16 @@ def test_errors(self):
775751
776752
777753class TestFix :
754+ def get_output_data_type (self , dtype ):
755+ # this is used to determine the output dtype of numpy array
756+ # which is on cpu so no need for checking has_support_aspect64
757+ if has_support_aspect16 () and dpnp .can_cast (dtype , dpnp .float16 ):
758+ return dpnp .float16
759+ if dpnp .can_cast (dtype , dpnp .float32 ):
760+ return dpnp .float32
761+ if dpnp .can_cast (dtype , dpnp .float64 ):
762+ return dpnp .float64
763+
778764 @pytest .mark .parametrize (
779765 "dt" , get_all_dtypes (no_none = True , no_complex = True )
780766 )
@@ -794,28 +780,25 @@ def test_complex(self, xp, dt):
794780 xp .fix (a )
795781
796782 @pytest .mark .parametrize (
797- "a_dt " , get_all_dtypes (no_none = True , no_bool = True , no_complex = True )
783+ "dt " , get_all_dtypes (no_none = True , no_complex = True )
798784 )
799- def test_out (self , a_dt ):
800- a = get_abs_array (
801- [[1.0 , 1.1 , 1.5 , 1.8 ], [- 1.0 , - 1.1 , - 1.5 , - 1.8 ]], a_dt
802- )
803- ia = dpnp .array (a )
804-
805- out_dt = _get_output_data_type (a .dtype )
806- out = numpy .zeros_like (a , dtype = out_dt )
807- iout = dpnp .array (out )
785+ def test_out (self , dt ):
786+ data = [[1.0 , 1.1 , 1.5 , 1.8 ], [- 1.0 , - 1.1 , - 1.5 , - 1.8 ]]
787+ a = get_abs_array (data , dtype = dt )
788+ # numpy output has the same dtype as input
789+ # dpnp output always has a floating point dtype
790+ dt_out = self .get_output_data_type (a .dtype )
791+ out = numpy .zeros_like (a , dtype = dt_out )
792+ ia , iout = dpnp .array (a ), dpnp .array (out )
808793
809794 result = dpnp .fix (ia , out = iout )
810795 expected = numpy .fix (a , out = out )
811796 assert_array_equal (result , expected )
812797
813798 @pytest .mark .skipif (not has_support_aspect16 (), reason = "no fp16 support" )
814- @pytest .mark .parametrize ("dt" , [bool , numpy .float16 ])
815- def test_out_float16 (self , dt ):
816- a = numpy .array (
817- [[1.0 , 1.1 ], [1.5 , 1.8 ], [- 1.0 , - 1.1 ], [- 1.5 , - 1.8 ]], dtype = dt
818- )
799+ def test_out_float16 (self ):
800+ data = [[1.0 , 1.1 ], [1.5 , 1.8 ], [- 1.0 , - 1.1 ], [- 1.5 , - 1.8 ]]
801+ a = numpy .array (data , dtype = numpy .float16 )
819802 out = numpy .zeros_like (a , dtype = numpy .float16 )
820803 ia , iout = dpnp .array (a ), dpnp .array (out )
821804
0 commit comments