
Commit 1efef6b

superbobry authored and Google-ML-Automation committed
[pallas:mosaic_gpu] emit_pipeline now correctly supports BlockSpecs in GMEM
This is necessary to replace the pipelining logic in the lowering with `emit_pipeline`.

PiperOrigin-RevId: 698858380
1 parent 96c0129 commit 1efef6b
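
What this enables, as a minimal usage sketch adapted from the `test_nested_emit` test added below in this commit: an outer `emit_pipeline` whose BlockSpecs live in GMEM does no SMEM staging and simply forwards the GMEM refs, so pipelines can be nested. Names and shapes here are illustrative, the imports assume the aliases used in the test file, and running it requires a Mosaic-GPU-capable device.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

num_steps = 4

def inner_body(x_smem, o_smem):
  # Only the inner pipeline stages (32, 16) blocks through SMEM.
  o_smem[...] = x_smem[...] + 1.0

def inner(x_gmem, o_gmem):
  plgpu.emit_pipeline(
      inner_body,
      in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
      out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
      grid=(num_steps,),
      max_concurrent_steps=2,
  )(x_gmem, o_gmem)

def outer(x_gmem, o_gmem):
  # BlockSpecs pinned to GMEM: the outer pipeline allocates no SMEM buffers
  # and hands the GMEM refs straight to the inner pipeline.
  plgpu.emit_pipeline(
      inner,
      in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
      out_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
      grid=(),
  )(x_gmem, o_gmem)

x = jnp.arange(32 * num_steps * 16, dtype=jnp.float32).reshape(32, num_steps * 16)
kernel_fn = pl.pallas_call(
    outer,
    in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
    out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)
# kernel_fn(x) should equal x + 1.0, as asserted in the test.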

2 files changed: +67 -15 lines changed


jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 34 additions & 15 deletions
@@ -46,7 +46,16 @@ class BufferedRef:
   spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True})
   is_index_invariant: bool = dataclasses.field(metadata={"static": True})
   gmem_ref: pallas_core.AbstractMemoryRef
-  smem_ref: pallas_core.AbstractMemoryRef  # [num_slots, *spec.block_shape]
+  # ``None`` if the ref is pinned to GMEM; otherwise, has shape
+  # [num_slots, *spec.block_shape].
+  smem_ref: pallas_core.AbstractMemoryRef | None
+
+  def get_ref_for_slot(
+      self, slot: int | jax.Array
+  ) -> pallas_core.AbstractMemoryRef:
+    if self.smem_ref is None:
+      return self.gmem_ref
+    return self.smem_ref.at[slot]
 
   def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
     index_map = self.spec.index_map
@@ -59,6 +68,9 @@ def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
     )
 
   def copy_in(self, slot, grid_indices, barrier_ref):
+    if not _in_smem(self.spec):
+      return
+    assert self.smem_ref is not None
     gmem_slices = self.compute_gmem_slice(grid_indices)
     gpu_primitives.copy_gmem_to_smem(
         self.gmem_ref.at[gmem_slices],  # pytype: disable=unsupported-operands
@@ -67,6 +79,9 @@ def copy_in(self, slot, grid_indices, barrier_ref):
     )
 
   def copy_out(self, slot, grid_indices, predicate=None):
+    if not _in_smem(self.spec):
+      return
+    assert self.smem_ref is not None
     gmem_slices = self.compute_gmem_slice(grid_indices)
     gpu_primitives.copy_smem_to_gmem(
         self.smem_ref.at[slot],
@@ -88,8 +103,8 @@ def _uses_arguments(
 def _is_index_invariant(
     spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid
 ) -> bool:
-  index_map = spec.index_map
-  assert index_map is not None
+  if (index_map := spec.index_map) is None:
+    return True
   return not any(_uses_arguments(index_map, len(grid)))
 
 
@@ -105,6 +120,10 @@ def _inc_grid_by_1(
   return tuple(reversed(next_indices))
 
 
+def _in_smem(spec: pallas_core.BlockSpec) -> bool:
+  return spec.memory_space in (None, gpu_core.SMEM)
+
+
 # ``pl.Slice`` uses a different pytree encoding, depending on whether the
 # start/size are static or dynamic. This leads to pytree structure mismatch
 # in the pipeline body. So, we define a different ``Slice`` class below.
@@ -166,6 +185,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
       if any(
           spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx]  # type: ignore
           for idx in range(1, len(grid) + 1)
+          if spec.block_shape is not None
       ):
         raise NotImplementedError(
             f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
@@ -174,14 +194,12 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
 
     in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
     in_smem_refs, out_smem_refs = util.split_list(
-        map(
-            lambda spec, ref: gpu_core.SMEM(
-                (max_concurrent_steps, *spec.block_shape),  # type: ignore
-                ref.dtype,
-            ),
-            it.chain(in_specs, out_specs),
-            gmem_refs,
-        ),
+        [
+            gpu_core.SMEM((max_concurrent_steps, *spec.block_shape), ref.dtype)  # type: ignore
+            if _in_smem(spec)
+            else None
+            for spec, ref in zip(it.chain(in_specs, out_specs), gmem_refs)
+        ],
         [len(in_specs)],
     )
     return pl.run_scoped(
@@ -194,7 +212,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
         out_smem_refs=out_smem_refs,
         barrier_ref=gpu_core.Barrier(
            # TODO(slebedev): Change this to arrive only once.
-            len(in_specs),
+            sum(map(_in_smem, in_specs)),
            num_barriers=max_concurrent_steps,
        ),
    )
@@ -233,9 +251,10 @@ def loop_body(step, carry):
       )
 
       with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
-        body(
-            *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs))
-        )
+        body(*(
+            bref.get_ref_for_slot(slot)
+            for bref in it.chain(in_brefs, out_brefs)
+        ))
 
       if not all(bref.is_index_invariant for bref in out_brefs):
         gpu_primitives.commit_smem()
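
Net effect of the pipeline.py changes above: a BufferedRef whose spec is pinned to GMEM now carries `smem_ref=None`, its `copy_in`/`copy_out` return early, and `get_ref_for_slot` hands the pipeline body the GMEM ref itself instead of an SMEM slot. A toy, self-contained mirror of that dispatch (the `ToyBufferedRef` class and the string refs below are illustrative stand-ins, not the JAX-internal `BufferedRef`):

import dataclasses
from typing import Any, Optional

@dataclasses.dataclass
class ToyBufferedRef:
  gmem_ref: Any
  smem_ref: Optional[Any]  # None when the ref is pinned to GMEM.

  def get_ref_for_slot(self, slot: int) -> Any:
    if self.smem_ref is None:
      return self.gmem_ref      # No staging buffer: expose GMEM directly.
    return self.smem_ref[slot]  # Otherwise index the [num_slots, ...] buffer.

# SMEM-staged ref: the body sees the per-slot SMEM buffer.
staged = ToyBufferedRef(gmem_ref="gmem", smem_ref=["smem slot 0", "smem slot 1"])
assert staged.get_ref_for_slot(1) == "smem slot 1"

# GMEM-pinned ref: no buffer was allocated, so the body sees the GMEM ref.
pinned = ToyBufferedRef(gmem_ref="gmem", smem_ref=None)
assert pinned.get_ref_for_slot(0) == "gmem"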

tests/pallas/mosaic_gpu_test.py

Lines changed: 33 additions & 0 deletions
@@ -1186,6 +1186,39 @@ def kernel_body(x_smem, o_smem):
     )
     np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
 
+  def test_nested_emit(self):
+    num_steps = 4
+
+    def kernel(x_gmem, o_gmem):
+      plgpu.emit_pipeline(
+          nested_kernel,
+          in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+          out_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+          grid=(),
+      )(x_gmem, o_gmem)
+
+    def nested_kernel(x_gmem, o_gmem):
+      plgpu.emit_pipeline(
+          nested_kernel_body,
+          in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
+          out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
+          grid=(num_steps,),
+          max_concurrent_steps=2,
+      )(x_gmem, o_gmem)
+
+    def nested_kernel_body(x_smem, o_smem):
+      o_smem[...] = x_smem[...] + 1.0
+
+    x = jnp.arange(32 * num_steps * 16)
+    x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
+    kernel_fn = pl.pallas_call(
+        kernel,
+        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+    )
+    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
+
   def test_emit_with_grid_invariant_output(self):
     num_steps = 4

0 commit comments
