@@ -77,12 +77,12 @@ def uses_tc_tiling(self):
7777 "xla_tpu_use_tc_device_shape_on_sc" , "false"
7878 ) == "true"
7979
80- def skip_if_tc_tiling (self ):
80+ def skip_if_tc_tiling (self , reason : str = "" ):
8181 use_tc_tiling = self .COMPILER_OPTIONS .get (
8282 "xla_tpu_use_tc_device_shape_on_sc" , "false"
8383 )
8484 if use_tc_tiling == "true" :
85- self .skipTest ("TC tiling is not supported" )
85+ self .skipTest (f "TC tiling is not supported. { reason } " )
8686
8787
8888class TCTilingMixin ():
@@ -470,7 +470,7 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
470470 np .testing .assert_array_equal (kernel (x , indices ), x [indices ])
471471
472472 def test_gather_1d_with_indexing (self ):
473- self .skip_if_tc_tiling ()
473+ self .skip_if_tc_tiling ("Small 1d gather does not work on TC tiling." )
474474 x = jnp .arange (4 * 4 * 8 ).reshape (4 , 4 , 8 )
475475 indices = jax .random .permutation (jax .random .key (42 ), jnp .arange (8 ))
476476
@@ -486,6 +486,22 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
486486
487487 np .testing .assert_array_equal (kernel (x , indices ), x [1 , 2 , indices ])
488488
489+ def test_gather_2d_with_indexing (self ):
490+ x = jnp .arange (4 * 16 * 128 ).reshape (4 , 16 , 128 )
491+ indices = jax .random .permutation (jax .random .key (42 ), jnp .arange (8 ))
492+
493+ @self .vector_subcore_kernel (
494+ out_shape = jax .ShapeDtypeStruct (shape = (8 , 128 ,), dtype = jnp .int32 ),
495+ in_specs = (
496+ pl .BlockSpec (memory_space = pltpu .HBM ),
497+ pl .BlockSpec (memory_space = pltpu .VMEM ),
498+ ),
499+ )
500+ def kernel (x_hbm_ref , indices_ref , o_ref ):
501+ pltpu .sync_copy (x_hbm_ref .at [1 , pl .ds (8 , 8 ), :].at [indices_ref ], o_ref )
502+
503+ np .testing .assert_array_equal (kernel (x , indices ), x [1 , 8 :][indices ])
504+
489505 def test_gather_1d_with_indexed_ref (self ):
490506 x = jnp .arange (16 )
491507 indices = jax .random .permutation (jax .random .key (42 ), jnp .arange (16 ))
0 commit comments