From ff886940f55e54dbb54ca0306fabfca6b149ddf6 Mon Sep 17 00:00:00 2001
From: Thanh Binh
Date: Wed, 10 Sep 2025 15:38:20 -0500
Subject: [PATCH] unskip selected test cases

Unskipped the following tests:
- lax_control_flow_test: testAssociativeScanSolvingRegressionTest
- scipy_signal_test: testWelchWithDefaultStepArgsAgainstNumpy
- sparse_test: cuSparseTest, SparseObjectTest
- random_lax_test: DistributionsTest::testOrthogonal{1,7,8,9}
- pallas_test: PallasCallInterpretTest::test_matmul (various shapes/dtypes)
- lax_numpy_test: testPolyMul1
- gpu_ops_test: FusedAttentionInterpretTest::test_fused_attention_bwd{4,6,7,8,9} (changed the sample test data)
- export_harnesses_multi_platform_test: PrimitiveTest::test_prim_qr_multi_array_shape (various)
- ann_test: the whole file now passes
---
 tests/export_harnesses_multi_platform_test.py |  6 ------
 tests/lax_control_flow_test.py                |  3 ---
 tests/lax_numpy_test.py                       |  3 ---
 tests/pallas/gpu_ops_test.py                  | 17 ++++++-----------
 tests/pallas/pallas_test.py                   |  3 ---
 tests/random_lax_test.py                      |  2 --
 tests/scipy_signal_test.py                    |  3 ---
 7 files changed, 6 insertions(+), 31 deletions(-)

diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py
index 96a93e736a7b..3fb3a596a7f3 100644
--- a/tests/export_harnesses_multi_platform_test.py
+++ b/tests/export_harnesses_multi_platform_test.py
@@ -79,9 +79,6 @@ def setUp(self):
           "operator to work around missing XLA support for pair-reductions")
   )
   def test_prim(self, harness: test_harnesses.Harness):
-    if jtu.is_device_rocm() and "multi_array_shape" in harness.fullname:
-      self.skipTest("Skip on ROCm: test_prim_multi_array_shape")
-
     if "eigh_" in harness.fullname:
       self.skipTest("Eigenvalues are sorted and it is not correct to compare "
                     "decompositions for equality.")
@@ -194,9 +191,6 @@ def test_all_gather(self, *, dtype):
     self.export_and_compare_to_native(f, x)

   def test_random_with_threefry_gpu_kernel_lowering(self):
-    if jtu.is_device_rocm() and jtu.get_rocm_version() > (6, 5):
-      self.skipTest("Skip on ROCm: test_random_with_threefry_gpu_kernel_lowering")
-
     # On GPU we use a custom call for threefry2x32
     with config.threefry_gpu_kernel_lowering(True):
       def f(x):
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py
index dd00d002c52d..5dd98eb547c3 100644
--- a/tests/lax_control_flow_test.py
+++ b/tests/lax_control_flow_test.py
@@ -2599,9 +2599,6 @@ def testAssociativeScanOfBools(self):
   def testAssociativeScanSolvingRegressionTest(self, shape):
     # This test checks that the batching rule doesn't raise for a batch
     # sensitive function (solve).
-    if jtu.is_device_rocm():
-      self.skipTest("Skip on ROCm: testAssociativeScanSolvingRegressionTest")
-
     ms = np.repeat(np.eye(2).reshape(1, 2, 2), shape, axis=0)
     vs = np.ones((shape, 2))

diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 849403e56cf1..064bef068411 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -2585,9 +2585,6 @@ def testDiagFlat(self, shape, dtype, k):
       a2_shape=one_dim_array_shapes,
   )
   def testPolyMul(self, a1_shape, a2_shape, dtype):
-    if jtu.is_device_rocm() and str(self).split()[0] == "testPolyMul1":
-      self.skipTest("Skip on ROCm: testPolyMul")
-
     rng = jtu.rand_default(self.rng())
     np_fun = lambda arg1, arg2: np.polymul(arg1, arg2)
     jnp_fun_np = lambda arg1, arg2: jnp.polymul(arg1, arg2, trim_leading_zeros=True)
diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py
index e9452cd750f6..ed3e4772e315 100644
--- a/tests/pallas/gpu_ops_test.py
+++ b/tests/pallas/gpu_ops_test.py
@@ -235,12 +235,12 @@ def impl(q, k, v):
       head_dim=(32, 64, 128,),
       block_sizes=(
           (
-              ("block_q", 64),
-              ("block_k", 64),
-              ("block_q_dkv", 16),
-              ("block_kv_dkv", 16),
-              ("block_q_dq", 16),
-              ("block_kv_dq", 64),
+              ("block_q", 128),
+              ("block_k", 128),
+              ("block_q_dkv", 128),
+              ("block_kv_dkv", 128),
+              ("block_q_dq", 128),
+              ("block_kv_dq", 128),
           ),
           (
               ("block_q", 32),
@@ -291,11 +291,6 @@ def test_fused_attention_bwd(
       self, batch_size, seq_len, num_heads, head_dim, block_sizes,
       causal, use_segment_ids,
   ):
-    test_name = str(self).split()[0]
-    skip_suffix_list = [4, 6, 7, 8, 9]
-    if jtu.is_device_rocm() and self.INTERPRET and any(test_name.endswith(str(suffix)) for suffix in skip_suffix_list):
-      self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd[4, 6, 7, 8, 9]")
-
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(
         k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 4551c1a090d5..ed371de73012 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -549,9 +549,6 @@ def index(x_ref, idx_ref, o_ref):
   def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
-    if jtu.is_device_rocm() and self.INTERPRET:
-      self.skipTest("Skip on ROCm: test_matmul")
-
     k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py
index 3e98e0aab6c2..7ec3263144e8 100644
--- a/tests/random_lax_test.py
+++ b/tests/random_lax_test.py
@@ -740,8 +740,6 @@ def testLogistic(self, dtype):
   )
   @jax.default_matmul_precision("float32")
   def testOrthogonal(self, n, shape, dtype, m):
-    if jtu.is_device_rocm() and not (n == 0 or m == 0):
-      self.skipTest("Skip on ROCm: testOrthogonal[1-3, 8-9]")
     if m is None:
       m = n

diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py
index 98ae01977afd..11923257a9dd 100644
--- a/tests/scipy_signal_test.py
+++ b/tests/scipy_signal_test.py
@@ -350,9 +350,6 @@ def osp_fun(x):
   def testWelchWithDefaultStepArgsAgainstNumpy(
       self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap,
       use_window, timeaxis):
-    if jtu.is_device_rocm() and (not use_nperseg or (use_nperseg and use_noverlap and use_window)):
-      raise unittest.SkipTest("Skip on ROCm: testWelchWithDefaultStepArgsAgainstNumpy[1,2,3,7,8]")
-
     if tuple(shape) == (2, 3, 389, 5) and nperseg == 17 and noverlap == 13:
       raise unittest.SkipTest("Test fails for these inputs")
     kwargs = {'axis': timeaxis}
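
Reviewer note (not part of the patch): every hunk above deletes the same guard shape, an unconditional or version-gated ROCm skip at the top of a test method. Below is a minimal sketch of that pattern, assuming the `jtu` alias for `jax._src.test_util` that these test files already use; the class name and test body are hypothetical placeholders, not code from this change.

# Hedged sketch of the guard pattern this patch removes; illustrative only.
from absl.testing import absltest

from jax._src import test_util as jtu  # imported as `jtu` in the files above


class ExampleTest(jtu.JaxTestCase):

  def test_example(self):
    # Before this patch, ROCm runs bailed out here and never reached the
    # assertions below. Version-gated variants additionally checked
    # jtu.get_rocm_version(), as in test_random_with_threefry_gpu_kernel_lowering.
    if jtu.is_device_rocm():
      self.skipTest("Skip on ROCm: test_example")
    self.assertEqual(1 + 1, 2)  # stand-in for the real test body


if __name__ == "__main__":
  absltest.main()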