Skip to content

Commit aff7714

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Pallas:MGPU] Fix an overly strict precision requirement in tests
They started failing after we allowed LLVM to perform contractions of adds and muls, but the difference is tiny. PiperOrigin-RevId: 701961845
1 parent 5d5b06c commit aff7714

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

tests/pallas/mosaic_gpu_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -492,9 +492,7 @@ def layer_norm_np(x):
492492
jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32)
493493
* input_factor
494494
)
495-
# TODO(cperivol): find out why in this particular case we have a small-ish error.
496-
rtol = 1e-07 if input_factor > 10 else 5e-5
497-
np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=rtol)
495+
np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5)
498496

499497
def test_print(self):
500498
@functools.partial(

0 commit comments

Comments
 (0)