|
20 | 20 | from . import xps |
21 | 21 | from ._array_module import _UndefinedStub |
22 | 22 | from ._array_module import bool as bool_dtype |
23 | | -from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128 |
| 23 | +from ._array_module import broadcast_to, eye, full |
24 | 24 | from .stubs import category_to_funcs |
25 | 25 | from .pytest_helpers import nargs |
26 | 26 | from .typing import Array, DataType, Scalar, Shape |
@@ -465,26 +465,21 @@ def scalars(draw, dtypes, finite=False, **kwds): |
465 | 465 | m, M = dh.dtype_ranges[dtype] |
466 | 466 | min_value = kwds.get('min_value', m) |
467 | 467 | max_value = kwds.get('max_value', M) |
468 | | - |
469 | 468 | return draw(integers(min_value, max_value)) |
| 469 | + |
470 | 470 | elif dtype == bool_dtype: |
471 | 471 | return draw(booleans()) |
472 | | - elif dtype == float64: |
473 | | - if finite: |
474 | | - return draw(floats(allow_nan=False, allow_infinity=False, **kwds)) |
475 | | - return draw(floats(), **kwds) |
476 | | - elif dtype == float32: |
477 | | - if finite: |
478 | | - return draw(floats(width=32, allow_nan=False, allow_infinity=False, **kwds)) |
479 | | - return draw(floats(width=32, **kwds)) |
480 | | - elif dtype == complex64: |
481 | | - if finite: |
482 | | - return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False)) |
483 | | - return draw(complex_numbers(width=32)) |
484 | | - elif dtype == complex128: |
485 | | - if finite: |
486 | | - return draw(complex_numbers(allow_nan=False, allow_infinity=False)) |
487 | | - return draw(complex_numbers()) |
| 472 | + |
| 473 | + elif dtype in dh.real_float_dtypes: |
| 474 | + f_kwds = dict(allow_nan=False, allow_infinity=False) if finite else dict() |
| 475 | + width = dh.dtype_nbits[dtype] # 32 or 64 |
| 476 | + return draw(floats(width=width, **f_kwds, **kwds)) |
| 477 | + |
| 478 | + elif dtype in dh.complex_dtypes: |
| 479 | + f_kwds = dict(allow_nan=False, allow_infinity=False) if finite else dict() |
| 480 | + width = dh.dtype_nbits[dtype] # 64 or 128 |
| 481 | + return draw(complex_numbers(width=width, **f_kwds, **kwds)) |
| 482 | + |
488 | 483 | else: |
489 | 484 | raise ValueError(f"Unrecognized dtype {dtype}") |
490 | 485 |
|
|
0 commit comments