Skip to content

Commit f18f62a

Browse files
chsigg authored and Google-ML-Automation committed
[XLA:GPU] Skip small tile sizes for sparse gemms on Ampere as well. Enable the JAX test again that has been failing.
PiperOrigin-RevId: 695360850
1 parent 8a7bf2e commit f18f62a

File tree

1 file changed

+0
-6
lines changed

1 file changed

+0
-6
lines changed

tests/sparse_nm_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,6 @@ def setUp(self):
4747
)
4848
@jtu.run_on_devices("gpu")
4949
def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx):
50-
if not jtu.is_cuda_compute_capability_at_least("9.0"):
51-
self.skipTest("Skipping test on Ampere because of bug b/377940729")
52-
5350
# Build keyword arguments
5451
kwargs = {
5552
"dimension_numbers": (((1,), (1,)), (tuple(), tuple())),
@@ -96,9 +93,6 @@ def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx):
9693
)
9794
@jtu.run_on_devices("gpu")
9895
def test_types(self, lhs_type, rhs_type, output_type):
99-
if not jtu.is_cuda_compute_capability_at_least("9.0"):
100-
self.skipTest("Skipping test on Ampere because of bug b/377940729")
101-
10296
tile_m, tile_n, tile_k = 64, 32, 128
10397

10498
# Build input data

0 commit comments

Comments (0)