
Commit 8b65620

apaszke authored and Google-ML-Automation committed
[Pallas MGPU] Use multiple k/v_consumed_barriers in the attention kernel
There's nothing technically preventing the compute threads from running ahead and signalling the consumption of k/v twice in case the memory thread ends up being temporarily starved. I don't think this was ever a problem in practice, since the GPU hardware scheduler is surprisingly fair, but it's good not to have races :)

PiperOrigin-RevId: 703520322
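For intuition, the synchronization pattern can be sketched with plain Python threading primitives rather than the Pallas/Mosaic GPU API. This is an illustrative analogy only: NUM_SLOTS, producer_step, and consumer_step are made-up names standing in for max_concurrent_steps, the memory warpgroup, and the compute warpgroups, and the data-ready barriers (k_barriers/v_barriers) are omitted.

import threading

NUM_SLOTS = 2  # stands in for max_concurrent_steps

# Racy variant: a single binary "consumed" flag shared by every buffer slot.
# If the consumer runs ahead and sets it for two different kv steps before
# the producer ever waits, the two signals collapse into one -- the analogue
# of the compute threads arriving twice on one barrier while the memory
# thread is starved.
shared_consumed = threading.Event()

# Per-slot variant, mirroring k/v_consumed_barriers.at[slot] after this
# change: the producer waits on and resets exactly the slot it is about to
# refill, so a consumption signal for one slot cannot be mistaken for the
# consumption of another.
per_slot_consumed = [threading.Event() for _ in range(NUM_SLOTS)]

def consumer_step(step: int) -> None:
    slot = step % NUM_SLOTS
    # ... read the k/v tile held in `slot` ...
    per_slot_consumed[slot].set()   # "arrive" on that slot's consumed barrier

def producer_step(step: int) -> None:
    slot = step % NUM_SLOTS
    per_slot_consumed[slot].wait()  # wait for that slot, and only that slot
    per_slot_consumed[slot].clear()
    # ... refill `slot` with the next k/v tile ...

In the kernel itself the same effect comes from allocating the consumed barriers with num_barriers=max_concurrent_steps and arriving on / waiting for the barrier at the buffer's slot index, as the diff below does.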
1 parent 08d31d0 commit 8b65620

File tree (1 file changed: +7, -13 lines)

jax/experimental/pallas/ops/gpu/attention_mgpu.py

Lines changed: 7 additions & 13 deletions
@@ -74,7 +74,7 @@ def kernel(q_ref, k_ref, v_ref, out_ref, scoped):
     wg_idx = lax.axis_index("wg")
     qo_smem2, k_smem, v_smem = smem_buffers
     k_barriers, v_barriers, q_barriers = buffer_barriers
-    k_consumed_barrier, v_consumed_barrier = consumed_barriers
+    k_consumed_barriers, v_consumed_barriers = consumed_barriers
     def perform_schedule_barrier():
       plgpu.barrier_arrive(schedule_barrier)
       plgpu.barrier_wait(schedule_barrier)
@@ -116,7 +116,7 @@ def compute_qk(acc_ref):
           perform_schedule_barrier()
           return acc_ref[...]
         qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32))
-        plgpu.barrier_arrive(k_consumed_barrier)
+        plgpu.barrier_arrive(k_consumed_barriers.at[slot])

         # Softmax
         # We keep m scaled by log2e to use FMA instructions when computing p.
@@ -153,7 +153,7 @@ def compute_pv(acc_ref):
           def _wait():
             plgpu.barrier_wait(k_barriers.at[wait_slot])
         acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc))
-        plgpu.barrier_arrive(v_consumed_barrier)
+        plgpu.barrier_arrive(v_consumed_barriers.at[slot])
         return acc, m_i, l_i
       if kv_seq_len % block_kv:
         raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}")
@@ -184,17 +184,12 @@ def kv_loop(kv_step, _):
         tma_step = kv_step + max_concurrent_steps
         tma_slot = lax.rem(kv_step, max_concurrent_steps)
         s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head)
-        plgpu.barrier_wait(k_consumed_barrier)
+        plgpu.barrier_wait(k_consumed_barriers.at[tma_slot])
         plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot])
-        plgpu.barrier_wait(v_consumed_barrier)
+        plgpu.barrier_wait(v_consumed_barriers.at[tma_slot])
         plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot])
       lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None)

-      def kv_epilogue(i, _):
-        plgpu.barrier_wait(k_consumed_barrier)
-        plgpu.barrier_wait(v_consumed_barrier)
-      lax.fori_loop(0, max_concurrent_steps, kv_epilogue, None)
-
   def run(refs):
     q_ref, k_ref, v_ref, out_ref = refs

@@ -210,7 +205,6 @@ def run(refs):
     @pl.core_map(mesh)
     def _kernel_entry():
       compute_wgs = 2
-      barrier_2wg = plgpu.Barrier(num_arrivals=compute_wgs)
       tiling = plgpu.TilingTransform((64, 64))
       swizzle = plgpu.SwizzleTransform(128)
       qo_scratch = plgpu.SMEM(
@@ -233,8 +227,8 @@ def _kernel_entry():
               plgpu.Barrier(1, num_barriers=max_concurrent_steps),
               plgpu.Barrier(1, num_barriers=compute_wgs),
           ),
-          (barrier_2wg, barrier_2wg),
-          barrier_2wg,
+          (plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2,
+          plgpu.Barrier(num_arrivals=compute_wgs),
       )

   _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf)))
