Skip to content

Commit e30e839

Browse files
committed
MAINT: remove explicit xp.complex128 from scalars() strategy
NB: while at it, note that complex64 cases used to generate values of 32 bits for *both real and imag parts*: so in a range of complex32 / two float16.
1 parent e0f3e37 commit e30e839

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from . import xps
2121
from ._array_module import _UndefinedStub
2222
from ._array_module import bool as bool_dtype
23-
from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128
23+
from ._array_module import broadcast_to, eye, full
2424
from .stubs import category_to_funcs
2525
from .pytest_helpers import nargs
2626
from .typing import Array, DataType, Scalar, Shape
@@ -465,26 +465,21 @@ def scalars(draw, dtypes, finite=False, **kwds):
465465
m, M = dh.dtype_ranges[dtype]
466466
min_value = kwds.get('min_value', m)
467467
max_value = kwds.get('max_value', M)
468-
469468
return draw(integers(min_value, max_value))
469+
470470
elif dtype == bool_dtype:
471471
return draw(booleans())
472-
elif dtype == float64:
473-
if finite:
474-
return draw(floats(allow_nan=False, allow_infinity=False, **kwds))
475-
return draw(floats(), **kwds)
476-
elif dtype == float32:
477-
if finite:
478-
return draw(floats(width=32, allow_nan=False, allow_infinity=False, **kwds))
479-
return draw(floats(width=32, **kwds))
480-
elif dtype == complex64:
481-
if finite:
482-
return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False))
483-
return draw(complex_numbers(width=32))
484-
elif dtype == complex128:
485-
if finite:
486-
return draw(complex_numbers(allow_nan=False, allow_infinity=False))
487-
return draw(complex_numbers())
472+
473+
elif dtype in dh.real_float_dtypes:
474+
f_kwds = dict(allow_nan=False, allow_infinity=False) if finite else dict()
475+
width = dh.dtype_nbits[dtype] # 32 or 64
476+
return draw(floats(width=width, **f_kwds, **kwds))
477+
478+
elif dtype in dh.complex_dtypes:
479+
f_kwds = dict(allow_nan=False, allow_infinity=False) if finite else dict()
480+
width = dh.dtype_nbits[dtype] # 64 or 128
481+
return draw(complex_numbers(width=width, **f_kwds, **kwds))
482+
488483
else:
489484
raise ValueError(f"Unrecognized dtype {dtype}")
490485

0 commit comments

Comments
 (0)