Skip to content

Commit 8008a3f

Browse files
committed
MAINT: remove explicit complex128 from fft assertions, of real<->complex
Reuse real_dtype_for/complex_dtype_for, which encapsulate the same mappings.
1 parent e30e839 commit 8008a3f

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,8 @@ def assert_dtype(
161161
def assert_float_to_complex_dtype(
162162
func_name: str, *, in_dtype: DataType, out_dtype: DataType
163163
):
164-
if in_dtype == xp.float32:
165-
expected = xp.complex64
166-
else:
167-
assert in_dtype == xp.float64 # sanity check
168-
expected = xp.complex128
164+
assert in_dtype in dh.real_float_dtypes # sanity check
165+
expected = dh.complex_dtype_for(in_dtype)
169166
assert_dtype(
170167
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
171168
)
@@ -174,13 +171,8 @@ def assert_float_to_complex_dtype(
174171
def assert_complex_to_float_dtype(
175172
func_name: str, *, in_dtype: DataType, out_dtype: DataType, repr_name: str = "out.dtype"
176173
):
177-
if in_dtype == xp.complex64:
178-
expected = xp.float32
179-
elif in_dtype == xp.complex128:
180-
expected = xp.float64
181-
else:
182-
assert in_dtype in (xp.float32, xp.float64) # sanity check
183-
expected = in_dtype
174+
assert in_dtype in dh.all_float_dtypes
175+
expected = dh.real_dtype_for(in_dtype)
184176
assert_dtype(
185177
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected, repr_name=repr_name
186178
)

0 commit comments

Comments
 (0)