Skip to content

Commit ed4e982

Browse files
dfm authored and Google-ML-Automation committed
Relax tolerance for LAX reduction test in float16.
At `float16` precision, one LAX reduction test was found to be flaky, and disabled in #25443. This change re-enables that test with a slightly relaxed tolerance instead. PiperOrigin-RevId: 706771186
1 parent b3177da commit ed4e982

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

tests/lax_numpy_reducers_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,6 @@ def np_fun(x):
342342
))
343343
def testReducerPromoteInt(self, name, rng_factory, shape, dtype, axis,
344344
keepdims, initial, inexact, promote_integers):
345-
if jtu.test_device_matches(["cpu"]) and name == "sum" and config.enable_x64.value and dtype == np.float16:
346-
raise unittest.SkipTest("sum op fails in x64 mode on CPU with dtype=float16") # b/383756018
347345
np_op = getattr(np, name)
348346
jnp_op = getattr(jnp, name)
349347
rng = rng_factory(self.rng())
@@ -364,7 +362,7 @@ def np_fun(x):
364362
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers)
365363
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
366364
args_maker = lambda: [rng(shape, dtype)]
367-
tol = {jnp.bfloat16: 3E-2}
365+
tol = {jnp.bfloat16: 3E-2, jnp.float16: 5e-3}
368366
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
369367
self._CompileAndCheck(jnp_fun, args_maker)
370368

0 commit comments

Comments (0)