@@ -229,7 +229,8 @@ def test_hypothesis( # type: ignore[no-any-decorated]
229229 cond_shape , * shapes = input_shapes
230230
231231 # cupy/cupy#8382
232- elements = {"allow_subnormal" : False } if library is Backend .CUPY else None
232+ # https://github.com/jax-ml/jax/issues/26658
233+ elements = {"allow_subnormal" : library not in (Backend .CUPY , Backend .JAX )}
233234
234235 fill_value = xp .asarray (
235236 data .draw (npst .arrays (dtype = dtype , shape = (), elements = elements ))
@@ -258,12 +259,9 @@ def f2(*args: Array) -> Array:
258259 # TODO remove asarrays once all backends support Array API 2024.12
259260 ref3 = xp .where (cond , * asarrays (f1 (* arrays ), float_fill_value , xp = xp ))
260261
261- # https://github.com/jax-ml/jax/issues/26658
262- atol = 1e-300 if library is Backend .JAX else 0
263-
264- xp_assert_close (res1 , ref1 , atol = atol , rtol = 2e-16 )
265- xp_assert_close (res2 , ref2 , atol = atol , rtol = 2e-16 )
266- xp_assert_close (res3 , ref3 , atol = atol , rtol = 2e-16 )
262+ xp_assert_close (res1 , ref1 , rtol = 2e-16 )
263+ xp_assert_equal (res2 , ref2 )
264+ xp_assert_equal (res3 , ref3 )
267265
268266
269267@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
0 commit comments