Skip to content

Commit c76e5fe

Browse files
superbobry authored and Google-ML-Automation committed
[pallas:mosaic_gpu] copy_smem_to_gmem now supports wait_read_only
PiperOrigin-RevId: 698343812
1 parent 14da7eb commit c76e5fe

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -207,7 +207,9 @@ def loop_body(step, carry):
207207
# Wait for the current GMEM->SMEM copy to complete.
208208
gpu_primitives.barrier_wait(barrier_ref.at[slot])
209209
# Wait for the previous output SMEM->GMEM copy to complete.
210-
gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1)
210+
gpu_primitives.wait_smem_to_gmem(
211+
max_concurrent_steps - 1, wait_read_only=True
212+
)
211213

212214
with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
213215
body(

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -363,20 +363,30 @@ def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None:
363363

364364

365365
@wait_smem_to_gmem_p.def_effectful_abstract_eval
366-
def _wait_smem_to_gmem_abstract_eval(n):
367-
del n # Unused.
366+
def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only):
367+
del n, wait_read_only # Unused.
368368
return (), {gpu_core._memory_effect}
369369

370370

371371
@lowering.register_lowering_rule(wait_smem_to_gmem_p)
372-
def _wait_smem_to_gmem_lowering(ctx: lowering.LoweringRuleContext, n):
373-
ctx.launch_ctx.await_async_copy(allow_groups=n)
372+
def _wait_smem_to_gmem_lowering(
373+
ctx: lowering.LoweringRuleContext, n, *, wait_read_only
374+
):
375+
ctx.launch_ctx.await_async_copy(
376+
allow_groups=n, await_read_only=wait_read_only
377+
)
374378
return ()
375379

376380

377-
def wait_smem_to_gmem(n: int) -> None:
378-
"""Waits until there are no more than ``n`` SMEM->GMEM copies in flight."""
379-
wait_smem_to_gmem_p.bind(n)
381+
def wait_smem_to_gmem(n: int, wait_read_only: bool = False) -> None:
382+
"""Waits until there are no more than ``n`` SMEM->GMEM copies in flight.
383+
384+
Args:
385+
n: The maximum number of copies in flight to wait for.
386+
wait_read_only: If ``True``, wait for the in flight copies to finish
387+
reading from SMEM. The writes to GMEM are not waited for.
388+
"""
389+
wait_smem_to_gmem_p.bind(n, wait_read_only=wait_read_only)
380390

381391

382392
# WGMMA on an accumulator reference

0 commit comments

Comments
 (0)