@@ -416,22 +416,21 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
416416
417417 np .testing .assert_array_equal (kernel (x , indices ), x [indices ])
418418
419- def test_gather_1d_with_indexing (self ):
420- self .skip_if_tc_tiling ()
421- x = jnp .arange (4 * 4 * 8 ).reshape (4 , 4 , 8 )
419+ def test_gather_2d_with_indexing (self ):
420+ x = jnp .arange (4 * 16 * 128 ).reshape (4 , 16 , 128 )
422421 indices = jax .random .permutation (jax .random .key (42 ), jnp .arange (8 ))
423422
424423 @self .vector_subcore_kernel (
425- out_shape = jax .ShapeDtypeStruct (shape = (8 ,), dtype = jnp .int32 ),
424+ out_shape = jax .ShapeDtypeStruct (shape = (8 , 128 , ), dtype = jnp .int32 ),
426425 in_specs = (
427426 pl .BlockSpec (memory_space = pltpu .HBM ),
428427 pl .BlockSpec (memory_space = pltpu .VMEM ),
429428 ),
430429 )
431430 def kernel (x_hbm_ref , indices_ref , o_ref ):
432- pltpu .sync_copy (x_hbm_ref .at [1 , 2 ].at [indices_ref ], o_ref )
431+ pltpu .sync_copy (x_hbm_ref .at [1 , pl . ds ( 8 , 8 ), : ].at [indices_ref ], o_ref )
433432
434- np .testing .assert_array_equal (kernel (x , indices ), x [1 , 2 , indices ])
433+ np .testing .assert_array_equal (kernel (x , indices ), x [1 , 8 :][ indices ])
435434
436435 def test_gather_1d_with_indexed_ref (self ):
437436 x = jnp .arange (16 )
0 commit comments