
Commit f18df8f

superbobry authored and Google-ML-Automation committed

[pallas:mosaic_gpu] Pulled delay_release into emit_pipeline

The implementation exactly matches the one we have in the lowering.

PiperOrigin-RevId: 698713343

1 parent e72b449 commit f18df8f


jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 27 additions & 6 deletions
@@ -125,20 +125,41 @@ def __eq__(self, other: _Slice) -> jax.Array:  # type: ignore
 
 
 def emit_pipeline(
-    body,
+    body: Callable[..., None],
     *,
     grid: pallas_core.StaticGrid,
     in_specs: Sequence[pallas_core.BlockSpec] = (),
     out_specs: Sequence[pallas_core.BlockSpec] = (),
     max_concurrent_steps: int = 1,
+    delay_release: int = 0,
 ):
-  """Creates a function to emit a manual pipeline within a Pallas kernel."""
+  """Creates a function to emit a manual pipeline within a Pallas kernel.
+
+  Args:
+    body: The pipeline body.
+    grid: The grid to use for the pipeline.
+    in_specs: The block specs for the inputs.
+    out_specs: The block specs for the outputs.
+    max_concurrent_steps: The maximum number of sequential stages that are
+      active concurrently. Defaults to 1.
+    delay_release: The number of steps to wait before reusing the input/output
+      references. Defaults to 0, and must be strictly smaller than
+      ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you
+      don't await the WGMMA in the body.
+  """
   num_steps = math.prod(grid)
 
+  if max_concurrent_steps <= delay_release:
+    raise ValueError(
+        "max_concurrent_steps must be greater than delay_release, but"
+        f" {max_concurrent_steps=}, {delay_release=}"
+    )
+
   # Shrink ``max_concurrent_steps`` if the total number of steps is lower to
-  # reduce the size of the allocated buffers below.
+  # reduce the size of the refs allocated in SMEM.
   if max_concurrent_steps > num_steps:
     max_concurrent_steps = num_steps
+    delay_release = 0  # No need to delay anything.
 
   def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
     for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
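
For orientation, here is a minimal sketch of a caller using the new parameter. This is not part of the commit: the body, grid, and block shapes are made-up assumptions, and only the emit_pipeline signature matches the diff above.

    # Hypothetical caller sketch; assumes the standard Pallas import alias
    # and that emit_pipeline is the function defined in this file.
    from jax.experimental import pallas as pl

    def body(x_smem, o_smem):
      # If this body issued a WGMMA without awaiting it, the docstring above
      # suggests delay_release=1 so its inputs are not overwritten early.
      o_smem[...] = x_smem[...] + 1.0

    pipeline = emit_pipeline(
        body,
        grid=(4,),
        in_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
        out_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
        max_concurrent_steps=2,
        delay_release=1,  # must stay strictly below max_concurrent_steps
    )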
@@ -208,7 +229,7 @@ def loop_body(step, carry):
       gpu_primitives.barrier_wait(barrier_ref.at[slot])
       # Wait for the previous output SMEM->GMEM copy to complete.
       gpu_primitives.wait_smem_to_gmem(
-          max_concurrent_steps - 1, wait_read_only=True
+          max_concurrent_steps - (1 + delay_release), wait_read_only=True
       )
 
       with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
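
A quick arithmetic check of the new wait bound, with illustrative values that are not from the commit (and assuming wait_smem_to_gmem(n) waits until at most n such copies remain in flight):

    # With two concurrent steps and delay_release=1, the loop now drains to
    # zero in-flight SMEM->GMEM copies before the refs can be reused, where
    # previously it tolerated one.
    max_concurrent_steps, delay_release = 2, 1
    assert max_concurrent_steps - 1 == 1                    # old bound
    assert max_concurrent_steps - (1 + delay_release) == 0  # new bound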
@@ -245,10 +266,10 @@ def loop_body(step, carry):
           predicate=lax.bitwise_or(slices_changed, is_last_step),
       )
 
-      fetch_step = step + max_concurrent_steps
+      fetch_step = step + (max_concurrent_steps - delay_release)
       fetch_slot = slot  # (x + y) % y == x % y
       jax.lax.cond(
-          fetch_step < num_steps,
+          lax.bitwise_and(fetch_step >= delay_release, fetch_step < num_steps),
           lambda: map(
              lambda bref: bref.copy_in(fetch_slot, fetch_indices, barrier_ref),
              in_brefs,
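
The guard change is easiest to see by tabulating the fetch schedule. Below is a standalone plain-Python sketch that mirrors only the index arithmetic from loop_body above (the function name is local to the sketch; no JAX required):

    def fetch_steps(num_steps, max_concurrent_steps, delay_release):
      # Mirrors `fetch_step` and the lax.bitwise_and guard from the diff.
      out = []
      for step in range(num_steps):
        fetch_step = step + (max_concurrent_steps - delay_release)
        if delay_release <= fetch_step < num_steps:
          out.append((step, fetch_step))
      return out

    # delay_release=0 reproduces the old schedule:
    print(fetch_steps(6, max_concurrent_steps=2, delay_release=0))
    # [(0, 2), (1, 3), (2, 4), (3, 5)]

    # delay_release=1 issues each fetch one step later, so a slot's refs
    # stay untouched for one extra step after their compute step:
    print(fetch_steps(6, max_concurrent_steps=2, delay_release=1))
    # [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]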
