diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index b1b32b04a8f4..1be1c2250902 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -36,6 +36,7 @@ import jax.numpy as jnp import numpy as np + if sys.platform != "win32": try: from jax.experimental.pallas import mosaic_gpu as plgpu_mgpu @@ -558,15 +559,16 @@ def kernel(*refs): (name, name, func, strategy) for name, func, strategy in UNARY_FUNCTIONS ) + @hp.given(hps.data()) def test_unary_primitives(self, name, func, shape_dtype_strategy, data): - if jtu.is_device_rocm and name in {"logistic", "reciprocal"}: - self.skipTest("Skip on ROCm: test_unary_primitives_[logistic,reciprocal]") + if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]: self.skip_if_mosaic_gpu() if self.INTERPRET: self.skipTest("This hypothesis test is slow, even more so in interpret mode.") + # We want exact equality here to match how JAX lowers to XLA tol = 0. if jtu.test_device_matches(["gpu"]): @@ -576,6 +578,8 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data): tol = 1e-6 elif name == "exp2": tol = 1e-6 + elif name == "reciprocal": + tol = 1e-6 def kernel(x_ref, y_ref): y_ref[...] = func(x_ref[...])