@@ -1307,8 +1307,8 @@ def _(i):
13071307 np .testing .assert_array_equal (kernel (), expected )
13081308
13091309 def test_barrier_via_pallas_call (self ):
1310- # TODO(slebedev): Fix the IR and re-enable the test.
1311- self .skipTest ("Failing at MLIR verification time " )
1310+ if not jtu . if_cloud_tpu_at_least ( 2025 , 11 , 22 ):
1311+ self .skipTest ("Test requires a newer libtpu " )
13121312
13131313 mesh = plsc .VectorSubcoreMesh (
13141314 core_axis_name = "core" , subcore_axis_name = "subcore" , num_cores = 1
@@ -1325,11 +1325,11 @@ def test_barrier_via_pallas_call(self):
13251325 shape = (mesh .num_subcores , vec_dim ), dtype = jnp .uint32
13261326 ),
13271327 out_specs = pl .BlockSpec ((1 , vec_dim ), lambda i : (i , 0 )),
1328- scratch_shapes = dict (
1329- shared_ref = pltpu .VMEM_SHARED (
1328+ scratch_shapes = (
1329+ pltpu .VMEM_SHARED (
13301330 (mesh .num_subcores , vec_dim ), jnp .uint32
13311331 ),
1332- vmem_ref = pltpu .VMEM ((vec_dim ,), jnp .uint32 ),
1332+ pltpu .VMEM ((vec_dim ,), jnp .uint32 ),
13331333 ),
13341334 )
13351335 def kernel (o_ref , shared_ref , vmem_ref ):
0 commit comments