Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions tests/pallas/tpu_sparsecore_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ def uses_tc_tiling(self):
"xla_tpu_use_tc_device_shape_on_sc", "false"
) == "true"

def skip_if_tc_tiling(self):
def skip_if_tc_tiling(self, reason: str = ""):
use_tc_tiling = self.COMPILER_OPTIONS.get(
"xla_tpu_use_tc_device_shape_on_sc", "false"
)
if use_tc_tiling == "true":
self.skipTest("TC tiling is not supported")
self.skipTest(f"TC tiling is not supported. {reason}")


class TCTilingMixin():
Expand Down Expand Up @@ -470,7 +470,7 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
np.testing.assert_array_equal(kernel(x, indices), x[indices])

def test_gather_1d_with_indexing(self):
self.skip_if_tc_tiling()
self.skip_if_tc_tiling("Small 1d gather does not work on TC tiling.")
x = jnp.arange(4 * 4 * 8).reshape(4, 4, 8)
indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))

Expand All @@ -486,6 +486,22 @@ def kernel(x_hbm_ref, indices_ref, o_ref):

np.testing.assert_array_equal(kernel(x, indices), x[1, 2, indices])

def test_gather_2d_with_indexing(self):
x = jnp.arange(4 * 16 * 128).reshape(4, 16, 128)
indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))

@self.vector_subcore_kernel(
out_shape=jax.ShapeDtypeStruct(shape=(8, 128,), dtype=jnp.int32),
in_specs=(
pl.BlockSpec(memory_space=pltpu.HBM),
pl.BlockSpec(memory_space=pltpu.VMEM),
),
)
def kernel(x_hbm_ref, indices_ref, o_ref):
pltpu.sync_copy(x_hbm_ref.at[1, pl.ds(8, 8), :].at[indices_ref], o_ref)

np.testing.assert_array_equal(kernel(x, indices), x[1, 8:][indices])

def test_gather_1d_with_indexed_ref(self):
x = jnp.arange(16)
indices = jax.random.permutation(jax.random.key(42), jnp.arange(16))
Expand Down
Loading