@@ -229,14 +229,17 @@ def test_hypothesis( # type: ignore[no-any-decorated]
229229 cond_shape , * shapes = input_shapes
230230
231231 # cupy/cupy#8382
232- elements = {"allow_subnormal" : False } if library is Backend .CUPY else None
232+ # https://github.com/jax-ml/jax/issues/26658
233+ elements = {"allow_subnormal" : library not in (Backend .CUPY , Backend .JAX )}
233234
234235 fill_value = xp .asarray (
235236 data .draw (npst .arrays (dtype = dtype , shape = (), elements = elements ))
236237 )
237238 float_fill_value = float (fill_value )
238239 arrays = tuple (
239- xp .asarray (data .draw (npst .arrays (dtype = dtype , shape = shape )))
240+ xp .asarray (
241+ data .draw (npst .arrays (dtype = dtype , shape = shape , elements = elements ))
242+ )
240243 for shape in shapes
241244 )
242245
@@ -258,12 +261,9 @@ def f2(*args: Array) -> Array:
258261 # TODO remove asarrays once all backends support Array API 2024.12
259262 ref3 = xp .where (cond , * asarrays (f1 (* arrays ), float_fill_value , xp = xp ))
260263
261- # https://github.com/jax-ml/jax/issues/26658
262- atol = 1e-300 if library is Backend .JAX else 0
263-
264- xp_assert_close (res1 , ref1 , atol = atol , rtol = 2e-16 )
265- xp_assert_close (res2 , ref2 , atol = atol , rtol = 2e-16 )
266- xp_assert_close (res3 , ref3 , atol = atol , rtol = 2e-16 )
264+ xp_assert_close (res1 , ref1 , rtol = 2e-16 )
265+ xp_assert_equal (res2 , ref2 )
266+ xp_assert_equal (res3 , ref3 )
267267
268268
269269@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
0 commit comments