Skip to content

Commit 104ff91

Browse files
IvyZXGoogle-ML-Automation
authored andcommitted
Enable gather_with_index test on tc tiling. Also add a slicing to the test.
PiperOrigin-RevId: 834957962
1 parent c833522 commit 104ff91

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -416,22 +416,21 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
416416

417417
np.testing.assert_array_equal(kernel(x, indices), x[indices])
418418

419-
def test_gather_1d_with_indexing(self):
420-
self.skip_if_tc_tiling()
421-
x = jnp.arange(4 * 4 * 8).reshape(4, 4, 8)
419+
def test_gather_2d_with_indexing(self):
420+
x = jnp.arange(4 * 16 * 128).reshape(4, 16, 128)
422421
indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))
423422

424423
@self.vector_subcore_kernel(
425-
out_shape=jax.ShapeDtypeStruct(shape=(8,), dtype=jnp.int32),
424+
out_shape=jax.ShapeDtypeStruct(shape=(8, 128,), dtype=jnp.int32),
426425
in_specs=(
427426
pl.BlockSpec(memory_space=pltpu.HBM),
428427
pl.BlockSpec(memory_space=pltpu.VMEM),
429428
),
430429
)
431430
def kernel(x_hbm_ref, indices_ref, o_ref):
432-
pltpu.sync_copy(x_hbm_ref.at[1, 2].at[indices_ref], o_ref)
431+
pltpu.sync_copy(x_hbm_ref.at[1, pl.ds(8, 8), :].at[indices_ref], o_ref)
433432

434-
np.testing.assert_array_equal(kernel(x, indices), x[1, 2, indices])
433+
np.testing.assert_array_equal(kernel(x, indices), x[1, 8:][indices])
435434

436435
def test_gather_1d_with_indexed_ref(self):
437436
x = jnp.arange(16)

0 commit comments

Comments
 (0)