Skip to content

Commit d99a637

Browse files
justinjfuGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Allow multiple indexing on refs
PiperOrigin-RevId: 713355813
1 parent 3848f0d commit d99a637

File tree

2 files changed

+61
-13
lines changed

2 files changed

+61
-13
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,19 +1012,23 @@ def _handle_indexing(
10121012
]
10131013
if not indexer_idxs:
10141014
return ref, transforms
1015-
if len(indexer_idxs) > 1:
1016-
raise NotImplementedError("Only one level of indexing supported.")
1017-
[indexer_idx] = indexer_idxs
1018-
indexer = cast(indexing.NDIndexer, transforms[indexer_idx])
1019-
if indexer.int_indexer_shape:
1020-
raise NotImplementedError("int_indexer_shape non-empty")
1021-
indices = _ndindexer_indices(indexer)
1022-
new_transforms_rev = []
1023-
for t in reversed(transforms[:indexer_idx]):
1024-
indices, new_t = t.untransform_index(indices)
1025-
new_transforms_rev.append(new_t)
1026-
new_transforms = [*reversed(new_transforms_rev), *transforms[indexer_idx + 1:]]
1027-
return mgpu.memref_slice(ref, indices), new_transforms
1015+
sliced_ref = ref
1016+
new_transforms = []
1017+
for t in transforms:
1018+
if not isinstance(t, indexing.NDIndexer):
1019+
new_transforms.append(t)
1020+
continue
1021+
indexer = cast(indexing.NDIndexer, t)
1022+
if indexer.int_indexer_shape:
1023+
raise NotImplementedError("int_indexer_shape non-empty")
1024+
indices = _ndindexer_indices(indexer)
1025+
new_transforms_rev = []
1026+
for t in reversed(new_transforms):
1027+
indices, new_t = t.untransform_index(indices)
1028+
new_transforms_rev.append(new_t)
1029+
sliced_ref = mgpu.memref_slice(sliced_ref, indices)
1030+
new_transforms = list(reversed(new_transforms_rev))
1031+
return sliced_ref, new_transforms
10281032

10291033

10301034
def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]:

tests/pallas/mosaic_gpu_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,50 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
301301
x = jnp.arange(256).astype(jnp.float32)
302302
np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)
303303

304+
def test_ref_with_multiple_indexers(self):
305+
x = jax.random.uniform(jax.random.key(0), (2, 64, 64))
306+
@functools.partial(
307+
pl.pallas_call,
308+
out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32),
309+
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
310+
scratch_shapes=[
311+
plgpu.SMEM(x.shape, jnp.float32),
312+
plgpu.Barrier(num_arrivals=1),
313+
],
314+
)
315+
def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
316+
plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref)
317+
plgpu.barrier_wait(barrier_ref)
318+
x_sliced = scratch_ref.at[0, :, :] # shape=(64, 64)
319+
o_ref[pl.ds(0, 32), :] = x_sliced[pl.ds(0, 32), :]
320+
o_ref[pl.ds(32, 32), :] = x_sliced[pl.ds(32, 32), :]
321+
np.testing.assert_array_equal(extract_x0(x), x[0])
322+
323+
def test_smem_multiple_indexers_with_transforms(self):
324+
x = jnp.arange(512 * 512).reshape(512, 512)
325+
@functools.partial(
326+
pl.pallas_call,
327+
grid=(4, 4),
328+
out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32),
329+
in_specs=(plgpu.GPUBlockSpec(
330+
block_shape=(128, 128),
331+
index_map=lambda i, j: (i, j),
332+
memory_space=plgpu.SMEM,
333+
transforms=(plgpu.TilingTransform((64, 32)),
334+
plgpu.SwizzleTransform(128))),),
335+
out_specs=(plgpu.GPUBlockSpec(
336+
block_shape=(64, 32),
337+
index_map=lambda i, j: (i, j),
338+
memory_space=plgpu.SMEM,)),
339+
)
340+
def kernel(x_ref, o_ref):
341+
x_sliced = x_ref.at[0:64, 32:96].at[:, 0:32] # get x_ref[0:64, 32:64]
342+
o_ref[...] = x_sliced[...]
343+
ref = jnp.concatenate([x[blk:blk+64, :] for blk in range(0, 512, 128)])
344+
ref = jnp.concatenate(
345+
[ref[:, blk+32:blk+64] for blk in range(0, 512, 128)], axis=1)
346+
np.testing.assert_array_equal(kernel(x), ref)
347+
304348
@parameterized.product(indexer=[0, 1, 2, 3])
305349
def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer):
306350
@functools.partial(

0 commit comments

Comments
 (0)