Skip to content

Commit a5c5ac7

Browse files
gulsumgudukbayRuturaj4
authored andcommitted
unskip tests in tests/lax_numpy_test and tests/pallas/gpy_ops_test.py and fix decorator order in tests/linalg_test
1 parent 1f85d7d commit a5c5ac7

File tree

3 files changed

+2
-4
lines changed

3 files changed

+2
-4
lines changed

tests/lax_numpy_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1562,7 +1562,6 @@ def testTrimZerosNotOneDArray(self):
15621562
)
15631563
@jax.default_matmul_precision("float32")
15641564
def testPoly(self, a_shape, dtype, rank):
1565-
self.skipTest("Skip Poly tests on ROCm")
15661565
if dtype in (np.float16, jnp.bfloat16, np.int16):
15671566
self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
15681567
elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]):

tests/linalg_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def testEigvals(self, shape, dtype):
361361

362362
@jtu.run_on_devices("cpu", "gpu")
363363
def testEigvalsInf(self):
364-
self.skipTest("Skip test on ROCm")
364+
#self.skipTest("Skip test on ROCm")
365365
# https://github.com/jax-ml/jax/issues/2661
366366
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
367367
self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
@@ -2027,12 +2027,12 @@ def testSqrtmEdgeCase(self, diag, expected, dtype):
20272027

20282028
self.assertAllClose(root, expected, check_dtypes=False)
20292029

2030-
@jtu.ignore_warning(category=FutureWarning, message="Don't treat future SciPy warning as error")
20312030
@jtu.sample_product(
20322031
cshape=[(), (4,), (8,), (4, 7), (2, 1, 5)],
20332032
cdtype=float_types + complex_types,
20342033
rshape=[(), (3,), (7,), (4, 4), (2, 4, 0)],
20352034
rdtype=float_types + complex_types + int_types)
2035+
@jtu.ignore_warning(category=FutureWarning, message="Don't treat future SciPy warning as error")
20362036
def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype):
20372037
if ((rdtype in [np.float64, np.complex128]
20382038
or cdtype in [np.float64, np.complex128])

tests/pallas/gpu_ops_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ def test_fused_attention_bwd(
239239
causal,
240240
use_segment_ids,
241241
):
242-
self.skipTest("Skip tests on ROCm")
243242
k1, k2, k3 = random.split(random.key(0), 3)
244243
q = random.normal(
245244
k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16

0 commit comments

Comments
 (0)