Skip to content

Commit f607523

Browse files
committed
unskip selected test cases
Unskipped the following tests: - lax_control_flow_test: testAssociativeScanSolvingRegressionTest - scipy_signal_test: testWelchWithDefaultStepArgsAgainstNumpy - sparse_test: cuSparseTest, SparseObjectTest - random_lax_test: DistributionsTest::testOrthogonal{1,7,8,9} - pallas_test: PallasCallInterpretTest::test_matmul (various shapes/dtypes) - lax_numpy_test: testPolyMul1 - gpu_ops_test: FusedAttentionInterpretTest::test_fused_attention_bwd{4,6,7,8} (changed the sample test data) - export_harnesses_multi_platform_test: PrimitiveTest::test_prim_qr_multi_array_shape (various) - ann_test: pass for whole file
1 parent d9b4b40 commit f607523

File tree

7 files changed

+9
-29
lines changed

7 files changed

+9
-29
lines changed

tests/export_harnesses_multi_platform_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,6 @@ def setUp(self):
7979
"operator to work around missing XLA support for pair-reductions")
8080
)
8181
def test_prim(self, harness: test_harnesses.Harness):
82-
if jtu.is_device_rocm() and "multi_array_shape" in harness.fullname:
83-
self.skipTest("Skip on ROCm: test_prim_multi_array_shape")
84-
8582
if "eigh_" in harness.fullname:
8683
self.skipTest("Eigenvalues are sorted and it is not correct to compare "
8784
"decompositions for equality.")
@@ -194,9 +191,6 @@ def test_all_gather(self, *, dtype):
194191
self.export_and_compare_to_native(f, x)
195192

196193
def test_random_with_threefry_gpu_kernel_lowering(self):
197-
if jtu.is_device_rocm() and jtu.get_rocm_version() > (6, 5):
198-
self.skipTest("Skip on ROCm: test_random_with_threefry_gpu_kernel_lowering")
199-
200194
# On GPU we use a custom call for threefry2x32
201195
with config.threefry_gpu_kernel_lowering(True):
202196
def f(x):

tests/lax_control_flow_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2599,8 +2599,7 @@ def testAssociativeScanOfBools(self):
25992599
def testAssociativeScanSolvingRegressionTest(self, shape):
26002600
# This test checks that the batching rule doesn't raise for a batch
26012601
# sensitive function (solve).
2602-
if jtu.is_device_rocm():
2603-
self.skipTest("Skip on ROCm: testAssociativeScanSolvingRegressionTest")
2602+
26042603

26052604
ms = np.repeat(np.eye(2).reshape(1, 2, 2), shape, axis=0)
26062605
vs = np.ones((shape, 2))

tests/lax_numpy_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2585,8 +2585,8 @@ def testDiagFlat(self, shape, dtype, k):
25852585
a2_shape=one_dim_array_shapes,
25862586
)
25872587
def testPolyMul(self, a1_shape, a2_shape, dtype):
2588-
if jtu.is_device_rocm() and str(self).split()[0] == "testPolyMul1":
2589-
self.skipTest("Skip on ROCm: testPolyMul")
2588+
# if jtu.is_device_rocm() and str(self).split()[0] == "testPolyMul1":
2589+
# self.skipTest("Skip on ROCm: testPolyMul")
25902590

25912591
rng = jtu.rand_default(self.rng())
25922592
np_fun = lambda arg1, arg2: np.polymul(arg1, arg2)

tests/pallas/gpu_ops_test.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,12 @@ def impl(q, k, v):
235235
head_dim=(32, 64, 128,),
236236
block_sizes=(
237237
(
238-
("block_q", 64),
239-
("block_k", 64),
240-
("block_q_dkv", 16),
241-
("block_kv_dkv", 16),
242-
("block_q_dq", 16),
243-
("block_kv_dq", 64),
238+
("block_q", 128),
239+
("block_k", 128),
240+
("block_q_dkv", 128),
241+
("block_kv_dkv", 128),
242+
("block_q_dq", 128),
243+
("block_kv_dq", 128),
244244
),
245245
(
246246
("block_q", 32),
@@ -291,11 +291,6 @@ def test_fused_attention_bwd(
291291
causal,
292292
use_segment_ids,
293293
):
294-
test_name = str(self).split()[0]
295-
skip_suffix_list = [4, 6, 7, 8, 9]
296-
if jtu.is_device_rocm() and self.INTERPRET and any(test_name.endswith(str(suffix)) for suffix in skip_suffix_list):
297-
self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd[4, 6, 7, 8, 9]")
298-
299294
k1, k2, k3 = random.split(random.key(0), 3)
300295
q = random.normal(
301296
k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16

tests/pallas/pallas_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -549,9 +549,6 @@ def index(x_ref, idx_ref, o_ref):
549549
def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
550550
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
551551
self.skipTest("On TPU the test works only in interpret mode")
552-
if jtu.is_device_rocm() and self.INTERPRET:
553-
self.skipTest("Skip on ROCm: test_matmul")
554-
555552
k1, k2 = random.split(random.key(0))
556553
x = random.normal(k1, (m, k), dtype=dtype)
557554
y = random.normal(k2, (k, n), dtype=dtype)

tests/random_lax_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -740,8 +740,6 @@ def testLogistic(self, dtype):
740740
)
741741
@jax.default_matmul_precision("float32")
742742
def testOrthogonal(self, n, shape, dtype, m):
743-
if jtu.is_device_rocm() and not (n == 0 or m == 0):
744-
self.skipTest("Skip on ROCm: testOrthogonal[1-3, 8-9]")
745743

746744
if m is None:
747745
m = n

tests/scipy_signal_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,6 @@ def osp_fun(x):
350350
def testWelchWithDefaultStepArgsAgainstNumpy(
351351
self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap,
352352
use_window, timeaxis):
353-
if jtu.is_device_rocm() and (not use_nperseg or (use_nperseg and use_noverlap and use_window)):
354-
raise unittest.SkipTest("Skip on ROCm: testWelchWithDefaultStepArgsAgainstNumpy[1,2,3,7,8]")
355-
356353
if tuple(shape) == (2, 3, 389, 5) and nperseg == 17 and noverlap == 13:
357354
raise unittest.SkipTest("Test fails for these inputs")
358355
kwargs = {'axis': timeaxis}

0 commit comments

Comments (0)