@@ -129,8 +129,8 @@ def forward_kernel(
     q_ptrs = (
         Q +
         off_b * stride_qb +
-        offs_qh[:, None, None] * stride_qh +
-        offs_m[None, :, None] * stride_qm +
+        offs_qh[None, :, None] * stride_qh +
+        offs_m[:, None, None] * stride_qm +
         offs_d[None, None, :]
     )

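The hunk above is representative of the whole change: the per-program tile layout moves from head-group major, (QUERY_HEAD_GROUPS, BLOCK, ...), to query-row major, (BLOCK, QUERY_HEAD_GROUPS, ...). The NumPy snippet below is not part of the diff; it is a small sketch, with toy sizes and assumed row-major strides, of the (BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM) offset grid that the rewritten q_ptrs expression builds by broadcasting.

import numpy as np

# Illustration only: toy sizes and strides, not the kernel's real values.
BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM = 4, 2, 8
stride_qm = QUERY_HEAD_GROUPS * BLOCK_HEADDIM  # assumed contiguous row-major q
stride_qh = BLOCK_HEADDIM

offs_m = np.arange(BLOCK)               # query rows handled by this program
offs_qh = np.arange(QUERY_HEAD_GROUPS)  # query heads sharing one kv head
offs_d = np.arange(BLOCK_HEADDIM)       # positions within the head dimension

# Same broadcasting pattern as the updated q_ptrs / out_ptrs expressions:
# rows vary along axis 0, head groups along axis 1, head dim along axis 2.
offsets = (
    offs_qh[None, :, None] * stride_qh +
    offs_m[:, None, None] * stride_qm +
    offs_d[None, None, :]
)
assert offsets.shape == (BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)

With this layout the softmax statistics (m_i, lse_i) and the accumulator acc_o below are kept as (BLOCK, QUERY_HEAD_GROUPS, ...) tiles as well, so the per-row indexing no longer needs the permutes the old code used.
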
@@ -152,48 +152,56 @@ def forward_kernel(

     # maximum

-    m_i = tl.zeros([BLOCK * QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
+    m_i = tl.zeros([BLOCK, QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")

     # lse

     lse_ptrs = (
         Lse +
         off_b * stride_lse_b +
-        offs_qh[:, None] * seqlen_q_rounded +
-        offs_m[None, :]
+        offs_qh[None, :] * seqlen_q_rounded +
+        offs_m[:, None]
     )

-    lse_i = tl.zeros([BLOCK * QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
+    lse_i = tl.zeros([BLOCK, QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")

     # output

     out_ptrs = (
         Out +
         off_b * stride_ob +
-        offs_qh[:, None, None] * stride_oh +
-        offs_m[None, :, None] * stride_om +
+        offs_qh[None, :, None] * stride_oh +
+        offs_m[:, None, None] * stride_om +
         offs_d[None, None, :]
     )

-    acc_o = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
+    acc_o = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM], dtype = tl.float32)

     # load queries, keys, values

     if EVEN_M & EVEN_N:
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs)
         else:
-            q = tl.load(q_ptrs, mask = offs_d[None, None, :] < headdim, other = 0.0)
+            q = tl.load(
+                q_ptrs,
+                mask = offs_d[None, None, :] < headdim,
+                other = 0.0
+            )
     else:
         if EVEN_HEADDIM:
-            q = tl.load(q_ptrs, mask = offs_m[None, :, None] < seqlen_q, other = 0.0)
+            q = tl.load(
+                q_ptrs,
+                mask = offs_m[:, None, None] < seqlen_q,
+                other = 0.0
+            )
         else:
             q = tl.load(
-                q_ptrs, mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim), other = 0.0
+                q_ptrs,
+                mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
+                other = 0.0
             )

-    q = q.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM])
-
     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
             k = tl.load(k_ptrs)
@@ -203,65 +211,75 @@ def forward_kernel(
         if EVEN_HEADDIM:
             k = tl.load(
                 k_ptrs,
-                mask = offs_n[:, None] < seqlen_k,
-                other = 0.0,
+                mask = offs_n[:, None] < seqlen_k,
+                other = 0.0,
             )
         else:
             k = tl.load(
                 k_ptrs,
-                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                other = 0.0,
+                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                other = 0.0,
             )

-    qk = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK], dtype = tl.float32)
+    qk = tl.zeros([BLOCK * QUERY_HEAD_GROUPS, BLOCK], dtype = tl.float32)
+
+    q = q.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
+
     qk += tl.dot(q, tl.trans(k))

+    qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
+
     if not EVEN_N:
         qk += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf"))

-    qk = qk.reshape([QUERY_HEAD_GROUPS, BLOCK, BLOCK])
-
-    qk += tl.where(offs_m[:, None] >= offs_n[None, :], 0, float("-inf"))
+    qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)

-    qk = qk.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK])
+    qk += tl.where(offs_m[:, None, None] >= offs_n[None, None, :], 0, float("-inf"))

-    m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-    p = tl.exp(qk * softmax_scale - m_ij[:, None])
+    m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
+    p = tl.exp(qk * softmax_scale - m_ij[:, :, None])

-    l_ij = tl.sum(p, 1)
+    l_ij = tl.sum(p, 2)

     acc_o_scale = tl.exp(m_i - m_ij)
-    acc_o *= acc_o_scale[:, None]
+    acc_o *= acc_o_scale[:, :, None]

     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
             v = tl.load(v_ptrs)
         else:
-            v = tl.load(v_ptrs, mask = offs_d[None, :] < headdim, other = 0.0)
+            v = tl.load(
+                v_ptrs,
+                mask = offs_d[None, :] < headdim,
+                other = 0.0
+            )
     else:
         if EVEN_HEADDIM:
             v = tl.load(
                 v_ptrs,
-                mask = offs_n[:, None] < seqlen_k,
-                other = 0.0,
+                mask = offs_n[:, None] < seqlen_k,
+                other = 0.0,
             )
         else:
             v = tl.load(
                 v_ptrs,
-                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                other = 0.0,
+                mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                other = 0.0,
             )

-    p = p.to(v.dtype)
-    acc_o += tl.dot(p, v)
+    p = p.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK).to(v.dtype)
+
+    causal_o = tl.dot(p, v)
+
+    acc_o += causal_o.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)

     # -- update statistics

     m_i = m_ij
     l_i_new = tl.exp(lse_i - m_ij) + l_ij
     lse_i = m_ij + tl.log(l_i_new)

-    # take care of the selected kv blocks
+    # # take care of the selected kv blocks

     kv_block_indices_ptrs = (
         kv_block_indices +
@@ -277,8 +295,7 @@ def forward_kernel(
         offs_m * stride_kvbl_m
     )

-    q = q.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
-    q = q.permute((1, 0, 2))
+    q = q.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
     q = tl.expand_dims(q, 2)
     q = tl.broadcast_to(q, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
     q = q.reshape(BLOCK, 16, BLOCK_HEADDIM)
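A note on the expand_dims / broadcast_to / reshape sequence above (and the matching sum(...) / QUERY_EXPAND_DIM reductions further down): the head-group axis is replicated QUERY_EXPAND_DIM times so that the dimension fed to tl.dot reaches 16, presumably to satisfy the minimum matrix size the dot instruction accepts, and averaging the replicas afterwards recovers the unexpanded result. The NumPy sketch below is not part of the diff; it uses toy sizes with GROUPS * EXPAND == 16, as in the kernel, to check that identity.

import numpy as np

BLOCK, GROUPS, EXPAND, HEADDIM = 4, 4, 4, 8   # toy sizes, GROUPS * EXPAND == 16
rng = np.random.default_rng(0)
q = rng.standard_normal((BLOCK, GROUPS, HEADDIM))
k = rng.standard_normal((BLOCK, HEADDIM, BLOCK))  # one selected kv block per query row

# direct per-group scores: (BLOCK, GROUPS, BLOCK)
direct = np.einsum('bgd,bdn->bgn', q, k)

# kernel-style: replicate each head group EXPAND times so the dot sees 16 rows,
# multiply, then collapse the identical replicas by summing and dividing by EXPAND
q_exp = np.broadcast_to(q[:, :, None, :], (BLOCK, GROUPS, EXPAND, HEADDIM))
q_exp = q_exp.reshape(BLOCK, GROUPS * EXPAND, HEADDIM)
scores = np.einsum('bed,bdn->ben', q_exp, k)
collapsed = scores.reshape(BLOCK, GROUPS, EXPAND, BLOCK).sum(axis = 2) / EXPAND

assert np.allclose(direct, collapsed)
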
@@ -290,11 +307,19 @@ def forward_kernel(
     blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]

     block_k_ptrs = (
-        K + off_b * stride_kb + off_h * stride_kh + (blocks_offs_n[:, :, None] * stride_kn + offs_d[None, None, :])
+        K +
+        off_b * stride_kb +
+        off_h * stride_kh +
+        blocks_offs_n[:, :, None] * stride_kn +
+        offs_d[None, None, :]
     )

     block_v_ptrs = (
-        V + off_b * stride_vb + off_h * stride_vh + (blocks_offs_n[:, :, None] * stride_vn + offs_d[None, None, :])
+        V +
+        off_b * stride_vb +
+        off_h * stride_vh +
+        blocks_offs_n[:, :, None] * stride_vn +
+        offs_d[None, None, :]
     )

     # load k of shape (m, n, d), sparsely selected by each query
@@ -304,50 +329,44 @@ def forward_kernel(
     # similarities

     block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
-    qk = tl.zeros([QUERY_HEAD_GROUPS, BLOCK, BLOCK], dtype = tl.float32)
+    qk = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK], dtype = tl.float32)

     k_block = k_block.reshape(BLOCK, BLOCK, BLOCK_HEADDIM)
     k_block = k_block.permute(0, 2, 1)

-    block_qk = tl.dot(q, k_block)
+    block_qk += tl.dot(q, k_block)
     block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
     block_qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM
-    block_qk = block_qk.permute(1, 0, 2)

     qk += block_qk
-    qk += tl.where(block_masks[:, None], 0, float("-inf"))
-
-    qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)
+    qk += tl.where(block_masks[:, None, None], 0, float("-inf"))

     # attention

-    m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-    p = tl.exp(qk * softmax_scale - m_ij[:, None])
+    m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
+    block_p = tl.exp(qk * softmax_scale - m_ij[:, :, None])

-    l_ij = tl.sum(p, 1)
+    l_ij = tl.sum(block_p, 2)

     # renormalize the running output

     acc_o_scale = tl.exp(m_i - m_ij)
-    acc_o = acc_o * acc_o_scale[:, None]
+    acc_o = acc_o * acc_o_scale[:, :, None]

     # aggregate values

     v_block = tl.load(block_v_ptrs)
     v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))

-    p = p.to(v_block.dtype)
-    p_expanded = p.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
-    p_expanded = p_expanded.permute(1, 0, 2)
+    block_p = block_p.to(v_block.dtype)
+    p_expanded = block_p.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
     p_expanded = tl.expand_dims(p_expanded, 2)
     p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
     p_expanded = p_expanded.reshape(BLOCK, 16, BLOCK)

     block_acc_o = tl.dot(p_expanded, v_block)
     block_acc_o = block_acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
     block_acc_o = tl.sum(block_acc_o, 2) / QUERY_EXPAND_DIM
-    block_acc_o = block_acc_o.permute(1, 0, 2)
-    block_acc_o = block_acc_o.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)

     acc_o += block_acc_o

@@ -360,28 +379,38 @@ def forward_kernel(
     # normalize accumulated out

     acc_o_scale = tl.exp(m_i - lse_i)
-    acc_o *= acc_o_scale[:, None]
+    acc_o *= acc_o_scale[:, :, None]

     # write back lse

-    lse_i = lse_i.reshape([QUERY_HEAD_GROUPS, BLOCK])
-    tl.store(lse_ptrs, lse_i, mask = offs_m[None, :] < seqlen_q)
+    lse_i = lse_i.reshape(BLOCK, QUERY_HEAD_GROUPS)
+    tl.store(lse_ptrs, lse_i, mask = offs_m[:, None] < seqlen_q)

     # write to output

-    acc_o = acc_o.reshape([QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM])
+    acc_o = acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)

     if EVEN_M:
         if EVEN_HEADDIM:
             tl.store(out_ptrs, acc_o)
         else:
-            tl.store(out_ptrs, acc_o, mask = offs_d[None, None, :] < headdim)
+            tl.store(
+                out_ptrs,
+                acc_o,
+                mask = offs_d[None, None, :] < headdim
+            )
     else:
         if EVEN_HEADDIM:
-            tl.store(out_ptrs, acc_o, mask = offs_m[None, :, None] < seqlen_q)
+            tl.store(
+                out_ptrs,
+                acc_o,
+                mask = offs_m[:, None, None] < seqlen_q
+            )
         else:
             tl.store(
-                out_ptrs, acc_o, mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim)
+                out_ptrs,
+                acc_o,
+                mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim)
             )

 def native_sparse_attn_forward(
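For reference, and not part of the diff: across both the causal block and the selected kv blocks the kernel carries the usual online-softmax state per query row and head group, a running maximum folded into the log-sum-exp (m_i, lse_i) and an unnormalized accumulator (acc_o) that is rescaled whenever the maximum grows. The single-row NumPy sketch below mirrors that recurrence; the helper name online_softmax_step is hypothetical and only meant to make the update order explicit.

import numpy as np

def online_softmax_step(qk, v, m_i, lse_i, acc_o, softmax_scale):
    # qk: (n,) scores against one key block; v: (n, d) values of that block
    m_ij = np.maximum(qk.max() * softmax_scale, lse_i)   # new running max, folded into lse
    p = np.exp(qk * softmax_scale - m_ij)                # shifted, unnormalized probabilities
    l_ij = p.sum()
    acc_o = acc_o * np.exp(m_i - m_ij)                   # rescale what was accumulated so far
    acc_o = acc_o + p @ v                                # aggregate this block's values
    lse_i = m_ij + np.log(np.exp(lse_i - m_ij) + l_ij)   # running log-sum-exp
    return m_ij, lse_i, acc_o

# Initialize m_i = lse_i = -inf and acc_o = 0; after the last block, the kernel's
# final normalization acc_o * exp(m_i - lse_i) yields softmax(scores) @ values.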