From c21774686c89cf98f184e2a737f63b6ae4f0a9df Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Fri, 25 Jul 2025 21:29:55 +0000 Subject: [PATCH] Adjust tolerance for reciprocal in unary primitives test --- tests/pallas/ops_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index b1b32b04a8f4..1be1c2250902 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -36,6 +36,7 @@ import jax.numpy as jnp import numpy as np + if sys.platform != "win32": try: from jax.experimental.pallas import mosaic_gpu as plgpu_mgpu @@ -558,15 +559,16 @@ def kernel(*refs): (name, name, func, strategy) for name, func, strategy in UNARY_FUNCTIONS ) + @hp.given(hps.data()) def test_unary_primitives(self, name, func, shape_dtype_strategy, data): - if jtu.is_device_rocm and name in {"logistic", "reciprocal"}: - self.skipTest("Skip on ROCm: test_unary_primitives_[logistic,reciprocal]") + if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]: self.skip_if_mosaic_gpu() if self.INTERPRET: self.skipTest("This hypothesis test is slow, even more so in interpret mode.") + # We want exact equality here to match how JAX lowers to XLA tol = 0. if jtu.test_device_matches(["gpu"]): @@ -576,6 +578,8 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data): tol = 1e-6 elif name == "exp2": tol = 1e-6 + elif name == "reciprocal": + tol = 1e-6 def kernel(x_ref, y_ref): y_ref[...] = func(x_ref[...])