@@ -301,6 +301,50 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
301301 x = jnp .arange (256 ).astype (jnp .float32 )
302302 np .testing .assert_array_equal (kernel (x )[indexer ], x [indexer ] + 1.0 )
303303
def test_ref_with_multiple_indexers(self):
    """Check that a ref sliced with `.at` can be indexed a second time.

    A (2, 64, 64) array is copied from GMEM into an SMEM scratch buffer.
    The kernel selects plane 0 through `.at` and then reads that view with
    further `pl.ds` row slices, writing the output in two 32-row halves.
    The kernel output must equal `x[0]`.
    """
    x = jax.random.uniform(jax.random.key(0), (2, 64, 64))
    half_rows = 32

    def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
        # Issue the GMEM -> SMEM copy and wait on the barrier it was given.
        plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref)
        plgpu.barrier_wait(barrier_ref)
        # First indexer: pick plane 0 of the scratch buffer.
        x_sliced = scratch_ref.at[0, :, :]  # shape=(64, 64)
        # Second indexer: read the sliced view half by half.
        for row_start in (0, half_rows):
            rows = pl.ds(row_start, half_rows)
            o_ref[rows, :] = x_sliced[rows, :]

    run_kernel = pl.pallas_call(
        extract_x0,
        out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32),
        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
        scratch_shapes=[
            plgpu.SMEM(x.shape, jnp.float32),
            plgpu.Barrier(num_arrivals=1),
        ],
    )
    np.testing.assert_array_equal(run_kernel(x), x[0])
322+
def test_smem_multiple_indexers_with_transforms(self):
    """Check chained `.at` indexers on an SMEM ref that carries transforms.

    The (512, 512) input is split over a (4, 4) grid into (128, 128) blocks
    placed in SMEM with `TilingTransform((64, 32))` followed by
    `SwizzleTransform(128)`. For every block the kernel chains two `.at`
    indexers that together select `x_ref[0:64, 32:64]` and stores that
    window into the matching (64, 32) output block. The expected array is
    assembled by taking the same window out of each block of `x`.
    """
    x = jnp.arange(512 * 512).reshape(512, 512)

    input_spec = plgpu.GPUBlockSpec(
        block_shape=(128, 128),
        index_map=lambda i, j: (i, j),
        memory_space=plgpu.SMEM,
        transforms=(
            plgpu.TilingTransform((64, 32)),
            plgpu.SwizzleTransform(128),
        ),
    )
    output_spec = plgpu.GPUBlockSpec(
        block_shape=(64, 32),
        index_map=lambda i, j: (i, j),
        memory_space=plgpu.SMEM,
    )

    def kernel(x_ref, o_ref):
        # First indexer narrows the block to rows 0:64, columns 32:96.
        window = x_ref.at[0:64, 32:96]
        # Second indexer keeps the first 32 columns of that window.
        x_sliced = window.at[:, 0:32]  # get x_ref[0:64, 32:64]
        o_ref[...] = x_sliced[...]

    run_kernel = pl.pallas_call(
        kernel,
        grid=(4, 4),
        out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32),
        in_specs=(input_spec,),
        out_specs=output_spec,
    )

    # Reference: rows [b, b + 64) and columns [b + 32, b + 64) for each
    # 128-wide block start b.
    block_starts = range(0, 512, 128)
    row_bands = jnp.concatenate([x[top:top + 64, :] for top in block_starts])
    expected = jnp.concatenate(
        [row_bands[:, left + 32:left + 64] for left in block_starts],
        axis=1,
    )
    np.testing.assert_array_equal(run_kernel(x), expected)
347+
304348 @parameterized .product (indexer = [0 , 1 , 2 , 3 ])
305349 def test_copy_gmem_to_smem_with_indexed_barrier (self , indexer ):
306350 @functools .partial (
0 commit comments