Skip to content

Commit c07c3a1

Browse files
[0.6.0-UT] Skipping aborted tests (HIP runtime issue) (#530)
1 parent 68d7228 commit c07c3a1

File tree

5 files changed

+22
-0
lines changed

5 files changed

+22
-0
lines changed

tests/image_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def testResizeUp(self, dtype, image_shape, target_shape, method):
160160
)
161161
def testResizeGradients(self, dtype, image_shape, target_shape, method,
162162
antialias):
163+
if jtu.is_device_rocm():
164+
self.skipTest("Skip on ROCm: testResizeGradients. Test aborts due to HIP runtime issue")
165+
163166
rng = jtu.rand_default(self.rng())
164167
args_maker = lambda: (rng(image_shape, dtype),)
165168
jax_fn = partial(image.resize, shape=target_shape, method=method,

tests/lax_numpy_reducers_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,9 @@ def testNanStdGrad(self):
671671
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
672672
@jax.default_matmul_precision('float32')
673673
def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights):
674+
if jtu.is_device_rocm():
675+
self.skipTest("Skip on ROCm: testCov. Test aborts due to HIP runtime issue")
676+
674677
rng = jtu.rand_default(self.rng())
675678
wrng = jtu.rand_positive(self.rng())
676679
wdtype = np.real(dtype(0)).dtype

tests/lax_scipy_sparse_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ def args_maker():
134134
dtype=float_types + complex_types,
135135
)
136136
def test_cg_as_solve(self, shape, dtype):
137+
if jtu.is_device_rocm():
138+
self.skipTest("Skip on ROCm: test_cg_as_solve. Test aborts due to HIP runtime issue")
137139

138140
rng = jtu.rand_default(self.rng())
139141
a = rng(shape, dtype)

tests/linalg_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,6 +1479,8 @@ def testLuOnZeroMatrix(self, lu):
14791479
dtype=float_types + complex_types,
14801480
)
14811481
def testLuGrad(self, shape, dtype):
1482+
if jtu.is_device_rocm():
1483+
self.skipTest("Skip on ROCm: testLuGrad. Test aborts due to HIP runtime issue")
14821484
rng = jtu.rand_default(self.rng())
14831485
a = rng(shape, dtype)
14841486
lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
@@ -1709,6 +1711,9 @@ def testTriangularSolveSingularBatched(self):
17091711
dtype=int_types + float_types + complex_types
17101712
)
17111713
def testExpm(self, n, batch_size, dtype):
1714+
if jtu.is_device_rocm():
1715+
self.skipTest("Skip on ROCm: testExpm. Test aborts due to HIP runtime issue")
1716+
17121717
if (jtu.test_device_matches(["cuda"]) and
17131718
_is_required_cuda_version_satisfied(12000)):
17141719
self.skipTest("Triggers a bug in cuda-12 b/287345077")
@@ -1861,6 +1866,8 @@ def sp_func(a):
18611866
dtype=float_types + complex_types,
18621867
)
18631868
def testIssue2131(self, n, dtype):
1869+
if jtu.is_device_rocm():
1870+
self.skipTest("Skip on ROCm: testIssue2131. Test aborts due to HIP runtime issue")
18641871
args_maker_zeros = lambda: [np.zeros((n, n), dtype)]
18651872
osp_fun = lambda a: osp.linalg.expm(a)
18661873
jsp_fun = lambda a: jsp.linalg.expm(a)
@@ -1896,6 +1903,8 @@ def args_maker():
18961903
dtype=float_types + complex_types,
18971904
)
18981905
def testExpmFrechet(self, n, dtype):
1906+
if jtu.is_device_rocm():
1907+
self.skipTest("Skip on ROCm: testExpmFrechet. Test aborts due to HIP runtime issue")
18991908
rng = jtu.rand_small(self.rng())
19001909
if dtype == np.float64 or dtype == np.complex128:
19011910
target_norms = [1.0e-2, 2.0e-1, 9.0e-01, 2.0, 3.0]
@@ -1934,6 +1943,9 @@ def args_maker():
19341943
dtype=float_types + complex_types,
19351944
)
19361945
def testExpmGrad(self, n, dtype):
1946+
if jtu.is_device_rocm():
1947+
self.skipTest("Skip on ROCm: testExpmGrad. Test aborts due to HIP runtime issue")
1948+
19371949
rng = jtu.rand_small(self.rng())
19381950
a = rng((n, n), dtype)
19391951
if dtype == np.float64 or dtype == np.complex128:

tests/qdwh_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def _testPolarDecomposition(self, a, u, h, tol):
6565

6666
def _testQdwh(self, a, dynamic_shape=None):
6767
"""Computes the polar decomposition and tests its basic properties."""
68+
if jtu.is_device_rocm():
69+
self.skipTest("Skip on ROCm: testQdwh. Test aborts due to HIP runtime issue")
6870
eps = jnp.finfo(a.dtype).eps
6971
u, h, iters, conv = qdwh.qdwh(a, dynamic_shape=dynamic_shape)
7072
tol = 13 * eps

0 commit comments

Comments
 (0)