Skip to content

Commit 11d5472

Browse files
committed
TST: use builtin hypothesis dtype infra
1 parent d6336b0 commit 11d5472

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/test_funcs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def test_device(self, xp: ModuleType, device: Device):
206206
@given(
207207
n_arrays=st.integers(min_value=1, max_value=3),
208208
rng_seed=st.integers(min_value=1000000000, max_value=9999999999),
209-
dtype=st.sampled_from((np.float32, np.float64)),
209+
dtype=npst.floating_dtypes(sizes=(32, 64)),
210210
p=st.floats(min_value=0, max_value=1),
211211
data=st.data(),
212212
)
@@ -223,7 +223,7 @@ def test_hypothesis(
223223
if (
224224
library.like(Backend.NUMPY)
225225
and NUMPY_VERSION < (2, 0)
226-
and dtype is np.float32
226+
and dtype.type is np.float32
227227
):
228228
pytest.xfail(reason="NumPy 1.x dtype promotion for scalars")
229229

@@ -236,17 +236,17 @@ def test_hypothesis(
236236
elements = {"allow_subnormal": not library.like(Backend.CUPY, Backend.JAX)}
237237

238238
fill_value = xp.asarray(
239-
data.draw(npst.arrays(dtype=dtype, shape=(), elements=elements))
239+
data.draw(npst.arrays(dtype=dtype.type, shape=(), elements=elements))
240240
)
241241
float_fill_value = float(fill_value)
242-
if library is Backend.CUPY and dtype is np.float32:
242+
if library is Backend.CUPY and dtype.type is np.float32:
243243
# Avoid data-dependent dtype promotion when encountering subnormals
244244
# close to the max float32 value
245245
float_fill_value = float(np.clip(float_fill_value, -1e38, 1e38))
246246

247247
arrays = tuple(
248248
xp.asarray(
249-
data.draw(npst.arrays(dtype=dtype, shape=shape, elements=elements))
249+
data.draw(npst.arrays(dtype=dtype.type, shape=shape, elements=elements))
250250
)
251251
for shape in shapes
252252
)

0 commit comments

Comments
 (0)