@@ -161,11 +161,8 @@ def assert_dtype(
161161def assert_float_to_complex_dtype (
162162 func_name : str , * , in_dtype : DataType , out_dtype : DataType
163163):
164- if in_dtype == xp .float32 :
165- expected = xp .complex64
166- else :
167- assert in_dtype == xp .float64 # sanity check
168- expected = xp .complex128
164+ assert in_dtype in dh .real_float_dtypes # sanity check
165+ expected = dh .complex_dtype_for (in_dtype )
169166 assert_dtype (
170167 func_name , in_dtype = in_dtype , out_dtype = out_dtype , expected = expected
171168 )
@@ -174,13 +171,8 @@ def assert_float_to_complex_dtype(
174171def assert_complex_to_float_dtype (
175172 func_name : str , * , in_dtype : DataType , out_dtype : DataType , repr_name : str = "out.dtype"
176173):
177- if in_dtype == xp .complex64 :
178- expected = xp .float32
179- elif in_dtype == xp .complex128 :
180- expected = xp .float64
181- else :
182- assert in_dtype in (xp .float32 , xp .float64 ) # sanity check
183- expected = in_dtype
174+ assert in_dtype in dh .all_float_dtypes
175+ expected = dh .real_dtype_for (in_dtype )
184176 assert_dtype (
185177 func_name , in_dtype = in_dtype , out_dtype = out_dtype , expected = expected , repr_name = repr_name
186178 )
0 commit comments