
Commit 12b45b3

superbobry authored and Google-ML-Automation committed
[pallas:mosaic_gpu] emit_pipeline no longer ignores transforms
PiperOrigin-RevId: 702726201
1 parent 2ac2692 commit 12b45b3

3 files changed: +45 -7 lines changed


jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 19 additions & 0 deletions
@@ -138,6 +138,14 @@ class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC):
   def to_gpu_transform(self) -> mgpu.MemRefTransform:
     pass
 
+  def batch(self, leading_rank: int):
+    """Returns a transform that accepts a ref with the extra `leading_rank` dims.
+
+    The returned transform should leave the leading dimensions unchanged and
+    only apply to the suffix of the shape.
+    """
+    raise NotImplementedError
+
   def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
     return aval.update(
         shape=self.to_gpu_transform().transform_shape(aval.shape)
@@ -161,6 +169,9 @@ def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
         ref, transforms=(*ref.transforms, UntileRef(self.tiling))
     )
 
+  def batch(self, leading_rank: int):
+    return self
+
   def to_gpu_transform(self) -> mgpu.MemRefTransform:
     return mgpu.TileTransform(self.tiling)
 
@@ -228,6 +239,11 @@ def __post_init__(self):
     if set(self.permutation) != set(range(len(self.permutation))):
       raise ValueError(f"Permutation {self.permutation} is not a permutation.")
 
+  def batch(self, leading_rank: int):
+    return TransposeTransform(
+        (*range(leading_rank), *(d + leading_rank for d in self.permutation))
+    )
+
   def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
     return dataclasses.replace(
         ref,
@@ -304,6 +320,9 @@ def __post_init__(self):
           " accepted."
       )
 
+  def batch(self, leading_rank: int):
+    return self
+
   def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
     return dataclasses.replace(
         ref, transforms=(*ref.transforms, UnswizzleRef(self.swizzle))
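
The key piece of the new contract is how a permutation is rebased when a buffer gains leading (pipeline-step) dimensions. As a standalone sketch, the expression from `TransposeTransform.batch` can be reproduced on plain tuples; `batch_permutation` below is a hypothetical helper for illustration only, not part of the Pallas API:

# Sketch of the index arithmetic in TransposeTransform.batch: the new
# leading dims map to themselves, and the original permutation is shifted
# so it only touches the trailing block dims.
def batch_permutation(permutation: tuple[int, ...], leading_rank: int) -> tuple[int, ...]:
  return (*range(leading_rank), *(d + leading_rank for d in permutation))

# A (1, 0) transpose of a 2D block, batched with one leading dim, becomes
# (0, 2, 1): the multi-step axis stays put.
assert batch_permutation((1, 0), leading_rank=1) == (0, 2, 1)

Tiling and swizzling already act only on the suffix of the shape, which is why their `batch` overrides can simply return `self`.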

jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 7 additions & 1 deletion
@@ -195,7 +195,13 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
     in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
     in_smem_refs, out_smem_refs = util.split_list(
         [
-            gpu_core.SMEM((max_concurrent_steps, *spec.block_shape), ref.dtype)  # type: ignore
+            gpu_core.SMEM(
+                (max_concurrent_steps, *spec.block_shape),  # type: ignore
+                ref.dtype,
+                transforms=tuple(
+                    t.batch(1) for t in getattr(spec, "transforms", ())
+                ),
+            )
             if _in_smem(spec)
             else None
             for spec, ref in zip(it.chain(in_specs, out_specs), gmem_refs)
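
With this change, the spec's transforms survive the extra buffering dimension: for `max_concurrent_steps=2` and a (64, 64) block, the pipeline allocates a (2, 64, 64) SMEM window and applies the batched transforms behind the leading step axis. The snippet below only mimics the tiling shape bookkeeping in plain Python, as a sketch of why `batch(1)` is enough; it is not the pipeline's actual code:

# Illustrative shape math: a tiling over the trailing block dims coexists
# with the leading multi-step dimension added by the pipeline.
max_concurrent_steps = 2
block_shape = (64, 64)
tile_shape = (64, 32)

smem_shape = (max_concurrent_steps, *block_shape)  # (2, 64, 64)
# Tiling rewrites only the trailing len(tile_shape) dims: each is divided
# by the tile size and the tile shape is appended; the step axis is untouched.
tiled_shape = (
    *smem_shape[:-len(tile_shape)],
    *(dim // tile for dim, tile in zip(smem_shape[-len(tile_shape):], tile_shape)),
    *tile_shape,
)
assert tiled_shape == (2, 1, 2, 64, 32)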

tests/pallas/mosaic_gpu_test.py

Lines changed: 19 additions & 6 deletions
@@ -820,7 +820,6 @@ def kernel(x_ref, o_ref):
     x = jnp.arange(256)
     np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256]))
 
-
   @parameterized.parameters(jnp.float16, jnp.float32)
   def test_wgmma(self, dtype):
     self.skip_unless_sm90a()
@@ -1233,23 +1232,37 @@ def body(step, _):
     )
     np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
 
-  def test_emit(self):
+  @parameterized.parameters(
+      ((),),
+      ((plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)),),
+  )
+  def test_emit(self, transforms):
     num_steps = 4
 
     def kernel(x_gmem, o_gmem):
       plgpu.emit_pipeline(
           kernel_body,
-          in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
-          out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
+          in_specs=[
+              plgpu.GPUBlockSpec(
+                  (64, 64), lambda i: (0, i), transforms=transforms
+              )
+          ],
+          out_specs=[
+              plgpu.GPUBlockSpec(
+                  (64, 64), lambda i: (0, i), transforms=transforms
+              )
+          ],
           grid=(num_steps,),
           max_concurrent_steps=2,
       )(x_gmem, o_gmem)
 
     def kernel_body(x_smem, o_smem):
+      # +1 for the indexing done by ``emit_pipeline``.
+      self.assertLen(x_smem.transforms, len(transforms) + 1)
       o_smem[...] = x_smem[...] + 1.0
 
-    x = jnp.arange(32 * num_steps * 16)
-    x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
+    x = jnp.arange(64 * num_steps * 64)
+    x = x.reshape(-1, num_steps * 64).astype(jnp.float32)
     kernel_fn = pl.pallas_call(
         kernel,
         in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
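
The `assertLen` in `kernel_body` captures the same idea from the user's side: the ref handed to the body carries the spec's transforms plus one extra indexing transform from `emit_pipeline` slicing the multi-step buffer. A trivial sketch of that count (the strings are placeholders, not real transform objects):

# Hypothetical bookkeeping for what kernel_body observes: the spec's own
# transforms plus the step-indexing applied by emit_pipeline.
spec_transforms = ("tiling(64, 32)", "swizzle(128)")
ref_transforms = (*spec_transforms, "index[step]")
assert len(ref_transforms) == len(spec_transforms) + 1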
