6 changes: 0 additions & 6 deletions tests/export_harnesses_multi_platform_test.py
@@ -79,9 +79,6 @@ def setUp(self):
           "operator to work around missing XLA support for pair-reductions")
       )
   def test_prim(self, harness: test_harnesses.Harness):
-    if jtu.is_device_rocm() and "multi_array_shape" in harness.fullname:
-      self.skipTest("Skip on ROCm: test_prim_multi_array_shape")
-
     if "eigh_" in harness.fullname:
       self.skipTest("Eigenvalues are sorted and it is not correct to compare "
                     "decompositions for equality.")
@@ -194,9 +191,6 @@ def test_all_gather(self, *, dtype):
     self.export_and_compare_to_native(f, x)
 
   def test_random_with_threefry_gpu_kernel_lowering(self):
-    if jtu.is_device_rocm() and jtu.get_rocm_version() > (6, 5):
-      self.skipTest("Skip on ROCm: test_random_with_threefry_gpu_kernel_lowering")
-
     # On GPU we use a custom call for threefry2x32
     with config.threefry_gpu_kernel_lowering(True):
       def f(x):
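For context on what the second hunk re-enables on ROCm: a minimal sketch of driving the threefry2x32 custom-call lowering outside the test harness. The test uses the internal `config.threefry_gpu_kernel_lowering` context manager; the flag spelling `jax_threefry_gpu_kernel_lowering` below is an assumption based on JAX's usual flag naming.

```python
# Sketch only: exercise the threefry2x32 GPU kernel lowering this test covers.
# Assumption: the config flag is spelled "jax_threefry_gpu_kernel_lowering".
import jax
import jax.numpy as jnp

jax.config.update("jax_threefry_gpu_kernel_lowering", True)

@jax.jit
def f(x):
    # Under this config, threefry2x32 lowers to a custom call on GPU
    # instead of being inlined into the XLA HLO.
    return x + jax.random.uniform(jax.random.key(42), x.shape, dtype=x.dtype)

print(f(jnp.ones(4, dtype=jnp.float32)))
```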
3 changes: 0 additions & 3 deletions tests/lax_control_flow_test.py
@@ -2599,9 +2599,6 @@ def testAssociativeScanOfBools(self):
   def testAssociativeScanSolvingRegressionTest(self, shape):
     # This test checks that the batching rule doesn't raise for a batch
     # sensitive function (solve).
-    if jtu.is_device_rocm():
-      self.skipTest("Skip on ROCm: testAssociativeScanSolvingRegressionTest")
-
     ms = np.repeat(np.eye(2).reshape(1, 2, 2), shape, axis=0)
     vs = np.ones((shape, 2))
 
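The regression test above runs an associative scan whose combine function calls solve, exercising the batch-sensitive batching rule. Roughly, the shape of the check looks like this; the combine below is an assumed stand-in, not the test's actual one:

```python
# Hedged sketch: an associative scan over (matrix, vector) pairs whose
# combine uses solve, which has a batch-sensitive batching rule.
import jax
import jax.numpy as jnp
import numpy as np

def combine(a, b):
    ma, va = a
    mb, vb = b
    # Both operands carry a leading batch axis inside associative_scan,
    # so solve's batching behavior is what the test probes.
    return mb @ ma, jnp.linalg.solve(mb, vb) + va

shape = 4
ms = np.repeat(np.eye(2).reshape(1, 2, 2), shape, axis=0)
vs = np.ones((shape, 2))
out_ms, out_vs = jax.lax.associative_scan(combine, (jnp.asarray(ms), jnp.asarray(vs)))
print(out_ms.shape, out_vs.shape)  # (4, 2, 2) (4, 2)
```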
3 changes: 0 additions & 3 deletions tests/lax_numpy_test.py
@@ -2585,9 +2585,6 @@ def testDiagFlat(self, shape, dtype, k):
       a2_shape=one_dim_array_shapes,
   )
   def testPolyMul(self, a1_shape, a2_shape, dtype):
-    if jtu.is_device_rocm() and str(self).split()[0] == "testPolyMul1":
-      self.skipTest("Skip on ROCm: testPolyMul")
-
     rng = jtu.rand_default(self.rng())
     np_fun = lambda arg1, arg2: np.polymul(arg1, arg2)
     jnp_fun_np = lambda arg1, arg2: jnp.polymul(arg1, arg2, trim_leading_zeros=True)
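For reviewers unfamiliar with the test: it checks `jnp.polymul` against `np.polymul` on random coefficient arrays. A minimal, self-contained version of that comparison:

```python
# jnp.polymul mirrors np.polymul on 1-D coefficient arrays.
import numpy as np
import jax.numpy as jnp

a1 = jnp.array([1.0, 2.0, 3.0])  # x^2 + 2x + 3
a2 = jnp.array([4.0, 5.0])       # 4x + 5

print(np.polymul(np.asarray(a1), np.asarray(a2)))    # [ 4. 13. 22. 15.]
# trim_leading_zeros=True strips leading zero coefficients so the
# result shape matches NumPy's.
print(jnp.polymul(a1, a2, trim_leading_zeros=True))  # [ 4. 13. 22. 15.]
```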
17 changes: 6 additions & 11 deletions tests/pallas/gpu_ops_test.py
@@ -235,12 +235,12 @@ def impl(q, k, v):
       head_dim=(32, 64, 128,),
       block_sizes=(
           (
-              ("block_q", 64),
-              ("block_k", 64),
-              ("block_q_dkv", 16),
-              ("block_kv_dkv", 16),
-              ("block_q_dq", 16),
-              ("block_kv_dq", 64),
+              ("block_q", 128),
+              ("block_k", 128),
+              ("block_q_dkv", 128),
+              ("block_kv_dkv", 128),
+              ("block_q_dq", 128),
+              ("block_kv_dq", 128),
           ),
           (
               ("block_q", 32),
@@ -291,11 +291,6 @@ def test_fused_attention_bwd(
       causal,
       use_segment_ids,
   ):
-    test_name = str(self).split()[0]
-    skip_suffix_list = [4, 6, 7, 8, 9]
-    if jtu.is_device_rocm() and self.INTERPRET and any(test_name.endswith(str(suffix)) for suffix in skip_suffix_list):
-      self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd[4, 6, 7, 8, 9]")
-
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(
         k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
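The first hunk bumps every backward-pass block size to 128. For reference, these tuples end up configuring the Pallas attention kernel roughly as sketched below; the `BlockSizes`/`mha` API from `jax.experimental.pallas.ops.gpu.attention` is an assumption about the current module layout, and the sketch needs a GPU to run.

```python
# Hedged sketch: feeding block sizes like the ones in this diff to the
# Pallas GPU attention op. Treat the exact mha/BlockSizes signatures as
# assumptions; seq_len must divide evenly by the block sizes.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu.attention import BlockSizes, mha

block_sizes = BlockSizes(
    block_q=128, block_k=128,
    block_q_dkv=128, block_kv_dkv=128,
    block_q_dq=128, block_kv_dq=128,
)

batch, seq_len, num_heads, head_dim = 2, 1024, 2, 64
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(k1, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
k = jax.random.normal(k2, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
v = jax.random.normal(k3, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)

out = mha(q, k, v, segment_ids=None, block_sizes=block_sizes, causal=True)
```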
3 changes: 0 additions & 3 deletions tests/pallas/pallas_test.py
@@ -549,9 +549,6 @@ def index(x_ref, idx_ref, o_ref):
   def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
-    if jtu.is_device_rocm() and self.INTERPRET:
-      self.skipTest("Skip on ROCm: test_matmul")
-
     k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
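`test_matmul` drives a blocked Pallas matmul across (bm, bn, bk) tilings. A stripped-down, whole-array sketch of the same `pallas_call` structure, run in interpret mode so it also works on CPU:

```python
# Minimal sketch of the kind of kernel test_matmul exercises. The real
# test tiles with block sizes (bm, bn, bk); this whole-array version
# only shows the pallas_call structure.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
    # Refs are array views inside the kernel; [...] reads/writes the block.
    o_ref[...] = jnp.dot(x_ref[...], y_ref[...])

def matmul(x, y):
    m, _ = x.shape
    _, n = y.shape
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        interpret=True,  # so the sketch runs without a GPU/TPU
    )(x, y)

x = jax.random.normal(jax.random.key(0), (64, 32), dtype=jnp.float32)
y = jax.random.normal(jax.random.key(1), (32, 16), dtype=jnp.float32)
assert jnp.allclose(matmul(x, y), x @ y, atol=1e-3)
```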
2 changes: 0 additions & 2 deletions tests/random_lax_test.py
@@ -740,8 +740,6 @@ def testLogistic(self, dtype):
   )
   @jax.default_matmul_precision("float32")
   def testOrthogonal(self, n, shape, dtype, m):
-    if jtu.is_device_rocm() and not (n == 0 or m == 0):
-      self.skipTest("Skip on ROCm: testOrthogonal[1-3, 8-9]")
 
     if m is None:
       m = n
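For context, `testOrthogonal` samples from `jax.random.orthogonal` and checks orthonormality; the core property reduces to:

```python
# jax.random.orthogonal draws Haar-distributed orthogonal matrices,
# so Q^T Q should be (numerically) the identity.
import jax
import jax.numpy as jnp

q = jax.random.orthogonal(jax.random.key(0), n=4)
print(jnp.allclose(q.T @ q, jnp.eye(4), atol=1e-5))  # True
```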
3 changes: 0 additions & 3 deletions tests/scipy_signal_test.py
@@ -350,9 +350,6 @@ def osp_fun(x):
   def testWelchWithDefaultStepArgsAgainstNumpy(
       self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap,
       use_window, timeaxis):
-    if jtu.is_device_rocm() and (not use_nperseg or (use_nperseg and use_noverlap and use_window)):
-      raise unittest.SkipTest("Skip on ROCm: testWelchWithDefaultStepArgsAgainstNumpy[1,2,3,7,8]")
-
     if tuple(shape) == (2, 3, 389, 5) and nperseg == 17 and noverlap == 13:
       raise unittest.SkipTest("Test fails for these inputs")
     kwargs = {'axis': timeaxis}
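The Welch test compares `jax.scipy.signal.welch` with default step arguments against SciPy. A minimal version of that comparison; the parameter values here are illustrative, not the test's parametrization:

```python
# Compare JAX's welch implementation against SciPy's reference.
import numpy as np
import scipy.signal as osp_signal
import jax.scipy.signal as jsp_signal

x = np.random.default_rng(0).normal(size=389).astype(np.float32)

f_sp, p_sp = osp_signal.welch(x, nperseg=17, noverlap=13)
f_jx, p_jx = jsp_signal.welch(x, nperseg=17, noverlap=13)

np.testing.assert_allclose(np.asarray(p_jx), p_sp, rtol=1e-4, atol=1e-5)
```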