@@ -74,7 +74,7 @@ def kernel(q_ref, k_ref, v_ref, out_ref, scoped):
7474 wg_idx = lax .axis_index ("wg" )
7575 qo_smem2 , k_smem , v_smem = smem_buffers
7676 k_barriers , v_barriers , q_barriers = buffer_barriers
77- k_consumed_barrier , v_consumed_barrier = consumed_barriers
77+ k_consumed_barriers , v_consumed_barriers = consumed_barriers
7878 def perform_schedule_barrier ():
7979 plgpu .barrier_arrive (schedule_barrier )
8080 plgpu .barrier_wait (schedule_barrier )
@@ -116,7 +116,7 @@ def compute_qk(acc_ref):
116116 perform_schedule_barrier ()
117117 return acc_ref [...]
118118 qk = pl .run_scoped (compute_qk , plgpu .ACC ((block_q , block_kv ), jnp .float32 ))
119- plgpu .barrier_arrive (k_consumed_barrier )
119+ plgpu .barrier_arrive (k_consumed_barriers . at [ slot ] )
120120
121121 # Softmax
122122 # We keep m scaled by log2e to use FMA instructions when computing p.
@@ -153,7 +153,7 @@ def compute_pv(acc_ref):
153153 def _wait ():
154154 plgpu .barrier_wait (k_barriers .at [wait_slot ])
155155 acc = pl .run_state (compute_pv )(plgpu .ACC .init (acc ))
156- plgpu .barrier_arrive (v_consumed_barrier )
156+ plgpu .barrier_arrive (v_consumed_barriers . at [ slot ] )
157157 return acc , m_i , l_i
158158 if kv_seq_len % block_kv :
159159 raise ValueError (f"{ kv_seq_len = } must be a multiple of { block_kv = } " )
@@ -184,17 +184,12 @@ def kv_loop(kv_step, _):
184184 tma_step = kv_step + max_concurrent_steps
185185 tma_slot = lax .rem (kv_step , max_concurrent_steps )
186186 s = (batch , pl .ds (tma_step * block_kv , block_kv ), kv_head )
187- plgpu .barrier_wait (k_consumed_barrier )
187+ plgpu .barrier_wait (k_consumed_barriers . at [ tma_slot ] )
188188 plgpu .copy_gmem_to_smem (k_ref .at [s ], k_smem .at [tma_slot ], k_barriers .at [tma_slot ])
189- plgpu .barrier_wait (v_consumed_barrier )
189+ plgpu .barrier_wait (v_consumed_barriers . at [ tma_slot ] )
190190 plgpu .copy_gmem_to_smem (v_ref .at [s ], v_smem .at [tma_slot ], v_barriers .at [tma_slot ])
191191 lax .fori_loop (0 , kv_seq_len // block_kv - max_concurrent_steps , kv_loop , None )
192192
193- def kv_epilogue (i , _ ):
194- plgpu .barrier_wait (k_consumed_barrier )
195- plgpu .barrier_wait (v_consumed_barrier )
196- lax .fori_loop (0 , max_concurrent_steps , kv_epilogue , None )
197-
198193 def run (refs ):
199194 q_ref , k_ref , v_ref , out_ref = refs
200195
@@ -210,7 +205,6 @@ def run(refs):
210205 @pl .core_map (mesh )
211206 def _kernel_entry ():
212207 compute_wgs = 2
213- barrier_2wg = plgpu .Barrier (num_arrivals = compute_wgs )
214208 tiling = plgpu .TilingTransform ((64 , 64 ))
215209 swizzle = plgpu .SwizzleTransform (128 )
216210 qo_scratch = plgpu .SMEM (
@@ -233,8 +227,8 @@ def _kernel_entry():
233227 plgpu .Barrier (1 , num_barriers = max_concurrent_steps ),
234228 plgpu .Barrier (1 , num_barriers = compute_wgs ),
235229 ),
236- (barrier_2wg , barrier_2wg ) ,
237- barrier_2wg ,
230+ (plgpu . Barrier ( num_arrivals = compute_wgs , num_barriers = max_concurrent_steps ),) * 2 ,
231+ plgpu . Barrier ( num_arrivals = compute_wgs ) ,
238232 )
239233
240234 _ , _ , _ , out = pl .run_state (run )((q , k , v , jnp .full_like (q , jnp .inf )))
0 commit comments