Skip to content

Commit f08801b

Browse files
apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Allow indexing to appear anywhere in the list of transforms
We only need to exchange the transforms preceding the indexer, while the rest can remain unmodified.

PiperOrigin-RevId: 688112088
1 parent a2bc8c2 commit f08801b

File tree

6 files changed

+51
-17
lines changed

6 files changed

+51
-17
lines changed

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,11 @@ def tree_unflatten(cls, metadata, arrays):
226226

227227

228228
def transpose_ref(
229-
ref: pallas_core.TransformedRef | pallas_core.AbstractMemoryRef,
229+
ref: pallas_core.TransformedRef | Any,
230230
permutation: tuple[int, ...],
231231
) -> pallas_core.TransformedRef:
232232
if not isinstance(ref, pallas_core.TransformedRef):
233-
if not isinstance(ref, pallas_core.AbstractMemoryRef):
233+
if not isinstance(jax_core.get_aval(ref), pallas_core.AbstractMemoryRef):
234234
raise TypeError("ref must be a reference")
235235
ref = pallas_core.TransformedRef(ref, transforms=())
236236
return pallas_core.TransformedRef(

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -871,22 +871,24 @@ def _handle_indexing(
871871
) -> tuple[ir.Value, Sequence[gpu_core.Transform]]:
872872
if not transforms:
873873
pass
874-
if not any(isinstance(t, indexing.NDIndexer) for t in transforms):
874+
indexer_idxs = [
875+
i for i, t in enumerate(transforms) if isinstance(t, indexing.NDIndexer)
876+
]
877+
if not indexer_idxs:
875878
return ref, transforms
876-
if any(
877-
isinstance(t, indexing.NDIndexer) for t in transforms[:-1]
878-
) or not isinstance(transforms[-1], indexing.NDIndexer):
879+
if len(indexer_idxs) > 1:
879880
raise NotImplementedError("Only one level of indexing supported.")
880-
881-
indexer = cast(indexing.NDIndexer, transforms[-1])
881+
[indexer_idx] = indexer_idxs
882+
indexer = cast(indexing.NDIndexer, transforms[indexer_idx])
882883
if indexer.int_indexer_shape:
883884
raise NotImplementedError("int_indexer_shape non-empty")
884885
indices = _ndindexer_indices(indexer)
885886
new_transforms_rev = []
886-
for t in reversed(transforms[:-1]):
887+
for t in reversed(transforms[:indexer_idx]):
887888
indices, new_t = t.untransform_index(indices)
888889
new_transforms_rev.append(new_t)
889-
return mgpu.memref_slice(ref, indices), new_transforms_rev[::-1]
890+
new_transforms = [*reversed(new_transforms_rev), *transforms[indexer_idx + 1:]]
891+
return mgpu.memref_slice(ref, indices), new_transforms
890892

891893

892894
def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]:

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ def copy_smem_to_gmem(
107107
if dst.memory_space is not gpu_core.GMEM:
108108
raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}")
109109
src, src_transforms = state_primitives.get_ref_and_transforms(
110-
src, None, "copy_smem_to_gmem"
110+
src, None, "copy_smem_to_gmem", force_trailing_indexer=False,
111111
)
112112
dst, dst_transforms = state_primitives.get_ref_and_transforms(
113-
dst, None, "copy_smem_to_gmem"
113+
dst, None, "copy_smem_to_gmem", force_trailing_indexer=False,
114114
)
115115
flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten(
116116
src_transforms
@@ -193,10 +193,10 @@ def copy_gmem_to_smem(
193193
if dst.memory_space is not gpu_core.SMEM:
194194
raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}")
195195
src, src_transforms = state_primitives.get_ref_and_transforms(
196-
src, None, "copy_gmem_to_smem"
196+
src, None, "copy_gmem_to_smem", force_trailing_indexer=False,
197197
)
198198
dst, dst_transforms = state_primitives.get_ref_and_transforms(
199-
dst, None, "copy_gmem_to_smem"
199+
dst, None, "copy_gmem_to_smem", force_trailing_indexer=False,
200200
)
201201
flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten(
202202
src_transforms
@@ -205,7 +205,7 @@ def copy_gmem_to_smem(
205205
dst_transforms
206206
)
207207
barrier, barrier_transforms = state_primitives.get_ref_and_transforms(
208-
barrier, None, "copy_gmem_to_smem"
208+
barrier, None, "copy_gmem_to_smem", force_trailing_indexer=False,
209209
)
210210
flat_barrier_transforms, barrier_transforms_treedef = tree_util.tree_flatten(
211211
barrier_transforms
@@ -284,7 +284,7 @@ def _barrier_arrive_lowering(
284284
def barrier_arrive(barrier: pallas_core.AbstractMemoryRef) -> None:
285285
"""Arrives at the given barrier."""
286286
barrier, transforms = state_primitives.get_ref_and_transforms(
287-
barrier, None, "barrier_arrive"
287+
barrier, None, "barrier_arrive", force_trailing_indexer=False,
288288
)
289289
flat_transforms, transforms_treedef = tree_util.tree_flatten(transforms)
290290
barrier_arrive_p.bind(
@@ -321,7 +321,7 @@ def _barrier_wait_lowering(
321321
def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None:
322322
"""Waits on the given barrier."""
323323
barrier, transforms = state_primitives.get_ref_and_transforms(
324-
barrier, None, "barrier_wait"
324+
barrier, None, "barrier_wait", force_trailing_indexer=False,
325325
)
326326
flat_transforms, transforms_treedef = tree_util.tree_flatten(transforms)
327327
barrier_wait_p.bind(

jax/_src/state/primitives.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def get_ref_and_transforms(
7373
ref_or_view: Any,
7474
idx: Indexer | tuple[Indexer, ...] | None,
7575
function_name: str,
76+
force_trailing_indexer: bool = True, # TODO(apaszke): Clean this up.
7677
) -> tuple[Any, tuple[Transform, ...]]:
7778
if isinstance(ref_or_view, TransformedRef):
7879
ref, transforms = ref_or_view.ref, ref_or_view.transforms
@@ -89,6 +90,8 @@ def get_ref_and_transforms(
8990
elif not isinstance(idx, tuple):
9091
idx = (idx,)
9192

93+
if not idx and not force_trailing_indexer:
94+
return ref, transforms
9295
if not idx and transforms and isinstance(transforms[-1], indexing.NDIndexer):
9396
return ref, transforms
9497
nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape)

jax/experimental/mosaic/gpu/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,12 @@ def async_copy(
400400
"Expected the SMEM reference to have the same shape as the"
401401
f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
402402
)
403+
smem_strides, _ = smem_ref_ty.get_strides_and_offset()
404+
if smem_strides != utils.get_contiguous_strides(smem_ref_ty.shape):
405+
raise ValueError(
406+
"async_copy needs the SMEM reference to be contiguous, but got"
407+
f" strides {smem_strides} for shape {smem_ref_ty.shape}"
408+
)
403409

404410
dyn_base_indices = list(dyn_base_indices)
405411
slice_shape = list(slice_shape)

tests/pallas/mosaic_gpu_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,29 @@ def kernel(x_ref, o_ref, barrier_ref):
314314
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
315315
np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0))
316316

317+
def test_indexing_before_transpose(self):
318+
def kernel(x_ref, o_ref, barrier_ref):
319+
for i in range(2):
320+
plgpu.copy_gmem_to_smem(
321+
x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier=barrier_ref
322+
)
323+
plgpu.barrier_wait(barrier_ref)
324+
325+
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
326+
out_spec = plgpu.GPUBlockSpec(
327+
(2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM,
328+
)
329+
f = pl.pallas_call(
330+
kernel,
331+
out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32),
332+
in_specs=(in_spec,),
333+
out_specs=out_spec,
334+
scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
335+
)
336+
x = jnp.arange(2 * 64 * 128, dtype=jnp.float32).reshape(2, 64, 128)
337+
xt = x.transpose((1, 0, 2))
338+
np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0))
339+
317340
def test_copy_gmem_to_smem_in_run_scoped(self):
318341
@functools.partial(
319342
pl.pallas_call,

0 commit comments

Comments
 (0)