Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import jax.numpy as jnp
import numpy as np


if sys.platform != "win32":
try:
from jax.experimental.pallas import mosaic_gpu as plgpu_mgpu
Expand Down Expand Up @@ -558,15 +559,16 @@ def kernel(*refs):
(name, name, func, strategy)
for name, func, strategy in UNARY_FUNCTIONS
)

@hp.given(hps.data())
def test_unary_primitives(self, name, func, shape_dtype_strategy, data):
if jtu.is_device_rocm and name in {"logistic", "reciprocal"}:
self.skipTest("Skip on ROCm: test_unary_primitives_[logistic,reciprocal]")

if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]:
self.skip_if_mosaic_gpu()

if self.INTERPRET:
self.skipTest("This hypothesis test is slow, even more so in interpret mode.")

# We want exact equality here to match how JAX lowers to XLA
tol = 0.
if jtu.test_device_matches(["gpu"]):
Expand All @@ -576,6 +578,8 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data):
tol = 1e-6
elif name == "exp2":
tol = 1e-6
elif name == "reciprocal":
tol = 1e-6

def kernel(x_ref, y_ref):
y_ref[...] = func(x_ref[...])
Expand Down