From 0df62464e488351d288c685b3c2c12849f0e3301 Mon Sep 17 00:00:00 2001
From: Ivy Zheng
Date: Thu, 20 Nov 2025 16:32:39 -0800
Subject: [PATCH] Enable gather_with_index test on tc tiling.

Also add a slicing to the test.

PiperOrigin-RevId: 834957962
---
 tests/pallas/tpu_sparsecore_pallas_test.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py
index 8f6717024229..2a2c17e8f270 100644
--- a/tests/pallas/tpu_sparsecore_pallas_test.py
+++ b/tests/pallas/tpu_sparsecore_pallas_test.py
@@ -77,12 +77,12 @@ def uses_tc_tiling(self):
         "xla_tpu_use_tc_device_shape_on_sc", "false"
     ) == "true"
 
-  def skip_if_tc_tiling(self):
+  def skip_if_tc_tiling(self, reason: str = ""):
     use_tc_tiling = self.COMPILER_OPTIONS.get(
         "xla_tpu_use_tc_device_shape_on_sc", "false"
     )
     if use_tc_tiling == "true":
-      self.skipTest("TC tiling is not supported")
+      self.skipTest(f"TC tiling is not supported. {reason}")
 
 
 class TCTilingMixin():
@@ -470,7 +470,7 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
     np.testing.assert_array_equal(kernel(x, indices), x[indices])
 
   def test_gather_1d_with_indexing(self):
-    self.skip_if_tc_tiling()
+    self.skip_if_tc_tiling("Small 1d gather does not work on TC tiling.")
     x = jnp.arange(4 * 4 * 8).reshape(4, 4, 8)
     indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))
 
@@ -486,6 +486,22 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
 
     np.testing.assert_array_equal(kernel(x, indices), x[1, 2, indices])
 
+  def test_gather_2d_with_indexing(self):
+    x = jnp.arange(4 * 16 * 128).reshape(4, 16, 128)
+    indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))
+
+    @self.vector_subcore_kernel(
+        out_shape=jax.ShapeDtypeStruct(shape=(8, 128,), dtype=jnp.int32),
+        in_specs=(
+            pl.BlockSpec(memory_space=pltpu.HBM),
+            pl.BlockSpec(memory_space=pltpu.VMEM),
+        ),
+    )
+    def kernel(x_hbm_ref, indices_ref, o_ref):
+      pltpu.sync_copy(x_hbm_ref.at[1, pl.ds(8, 8), :].at[indices_ref], o_ref)
+
+    np.testing.assert_array_equal(kernel(x, indices), x[1, 8:][indices])
+
   def test_gather_1d_with_indexed_ref(self):
     x = jnp.arange(16)
     indices = jax.random.permutation(jax.random.key(42), jnp.arange(16))