Skip to content

Commit 0df6246

Browse files
IvyZXGoogle-ML-Automation
authored andcommitted
Enable gather_with_index test on tc tiling. Also add a slicing to the test.
PiperOrigin-RevId: 834957962
1 parent 1fb2ec4 commit 0df6246

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,12 @@ def uses_tc_tiling(self):
7777
"xla_tpu_use_tc_device_shape_on_sc", "false"
7878
) == "true"
7979

80-
def skip_if_tc_tiling(self):
80+
def skip_if_tc_tiling(self, reason: str = ""):
8181
use_tc_tiling = self.COMPILER_OPTIONS.get(
8282
"xla_tpu_use_tc_device_shape_on_sc", "false"
8383
)
8484
if use_tc_tiling == "true":
85-
self.skipTest("TC tiling is not supported")
85+
self.skipTest(f"TC tiling is not supported. {reason}")
8686

8787

8888
class TCTilingMixin():
@@ -470,7 +470,7 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
470470
np.testing.assert_array_equal(kernel(x, indices), x[indices])
471471

472472
def test_gather_1d_with_indexing(self):
473-
self.skip_if_tc_tiling()
473+
self.skip_if_tc_tiling("Small 1d gather does not work on TC tiling.")
474474
x = jnp.arange(4 * 4 * 8).reshape(4, 4, 8)
475475
indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))
476476

@@ -486,6 +486,22 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
486486

487487
np.testing.assert_array_equal(kernel(x, indices), x[1, 2, indices])
488488

489+
def test_gather_2d_with_indexing(self):
490+
x = jnp.arange(4 * 16 * 128).reshape(4, 16, 128)
491+
indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))
492+
493+
@self.vector_subcore_kernel(
494+
out_shape=jax.ShapeDtypeStruct(shape=(8, 128,), dtype=jnp.int32),
495+
in_specs=(
496+
pl.BlockSpec(memory_space=pltpu.HBM),
497+
pl.BlockSpec(memory_space=pltpu.VMEM),
498+
),
499+
)
500+
def kernel(x_hbm_ref, indices_ref, o_ref):
501+
pltpu.sync_copy(x_hbm_ref.at[1, pl.ds(8, 8), :].at[indices_ref], o_ref)
502+
503+
np.testing.assert_array_equal(kernel(x, indices), x[1, 8:][indices])
504+
489505
def test_gather_1d_with_indexed_ref(self):
490506
x = jnp.arange(16)
491507
indices = jax.random.permutation(jax.random.key(42), jnp.arange(16))

0 commit comments

Comments
 (0)