Skip to content

Commit 9ebf41a

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:sc] Do not skip test_barrier_via_pallas_call
PiperOrigin-RevId: 834773228
1 parent b656500 commit 9ebf41a

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,8 +1307,8 @@ def _(i):
13071307
np.testing.assert_array_equal(kernel(), expected)
13081308

13091309
def test_barrier_via_pallas_call(self):
1310-
# TODO(slebedev): Fix the IR and re-enable the test.
1311-
self.skipTest("Failing at MLIR verification time")
1310+
if not jtu.if_cloud_tpu_at_least(2025, 11, 22):
1311+
self.skipTest("Test requires a newer libtpu")
13121312

13131313
mesh = plsc.VectorSubcoreMesh(
13141314
core_axis_name="core", subcore_axis_name="subcore", num_cores=1
@@ -1325,11 +1325,11 @@ def test_barrier_via_pallas_call(self):
13251325
shape=(mesh.num_subcores, vec_dim), dtype=jnp.uint32
13261326
),
13271327
out_specs=pl.BlockSpec((1, vec_dim), lambda i: (i, 0)),
1328-
scratch_shapes=dict(
1329-
shared_ref=pltpu.VMEM_SHARED(
1328+
scratch_shapes=(
1329+
pltpu.VMEM_SHARED(
13301330
(mesh.num_subcores, vec_dim), jnp.uint32
13311331
),
1332-
vmem_ref=pltpu.VMEM((vec_dim,), jnp.uint32),
1332+
pltpu.VMEM((vec_dim,), jnp.uint32),
13331333
),
13341334
)
13351335
def kernel(o_ref, shared_ref, vmem_ref):

0 commit comments

Comments
 (0)