@@ -95,6 +95,7 @@ def forward_kernel(
     EVEN_HEADDIM: tl.constexpr,
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
+    QUERY_EXPAND_DIM: tl.constexpr,
     NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     start_m = tl.program_id(0)
@@ -261,6 +262,12 @@ def forward_kernel(
         offs_m * stride_kvbl_m
     )

+    q = q.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
+    q = q.permute((1, 0, 2))
+    q = tl.expand_dims(q, 2)
+    q = tl.broadcast_to(q, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
+    q = q.reshape(BLOCK, 16, BLOCK_HEADDIM)
+
     for off_sel_kv_block in range(NUM_SEL_KV_BLOCKS):
         block_indices = tl.load(kv_block_indices_ptrs + off_sel_kv_block)
         block_masks = tl.load(kv_block_mask_ptrs + off_sel_kv_block)
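The added lines above pad the grouped queries out to a fixed 16 rows per query position, presumably so the later tl.dot sees a dimension of at least 16. A rough PyTorch sketch of the same reshape, with illustrative shapes only (QUERY_HEAD_GROUPS = 4 and the block sizes are assumptions, not taken from the kernel):

import torch

QUERY_HEAD_GROUPS = 4                       # assumed number of query heads per KV head
QUERY_EXPAND_DIM = 16 // QUERY_HEAD_GROUPS
BLOCK, BLOCK_HEADDIM = 64, 64

q = torch.randn(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)

# (g, m, d) -> (m, g, d) -> (m, g, 1, d) -> (m, g, e, d) -> (m, 16, d)
q = q.permute(1, 0, 2)
q = q.unsqueeze(2).expand(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
q = q.reshape(BLOCK, 16, BLOCK_HEADDIM)
assert q.shape == (BLOCK, 16, BLOCK_HEADDIM)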
@@ -282,18 +289,21 @@ def forward_kernel(
         # similarities

         block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
-        qk = tl.zeros([BLOCK, BLOCK], dtype = tl.float32)
+        qk = tl.zeros([QUERY_HEAD_GROUPS, BLOCK, BLOCK], dtype = tl.float32)

-        k_block = tl.reshape(k_block, (BLOCK, BLOCK, BLOCK_HEADDIM))
-        k_block = tl.permute(k_block, (0, 2, 1))
+        k_block = k_block.reshape(BLOCK, BLOCK, BLOCK_HEADDIM)
+        k_block = k_block.permute(0, 2, 1)

-        q_expanded = tl.expand_dims(q, 1)
-        q_expanded = tl.broadcast_to(q_expanded, (BLOCK, 16, BLOCK_HEADDIM))
+        block_qk = tl.dot(q, k_block)
+        block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
+        block_qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM
+        block_qk = block_qk.permute(1, 0, 2)

-        block_qk = tl.dot(q_expanded, k_block)
-        qk += tl.sum(block_qk, 1) / 16.
+        qk += block_qk
         qk += tl.where(block_masks[:, None], 0, float("-inf"))

+        qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)
+
         # attention

         m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
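After the padded tl.dot, the duplicated rows are averaged back to a single score per head group and the group axis is folded into the row dimension before the softmax statistics. A PyTorch equivalent of that reduction, continuing the sketch above (k_block here is a stand-in for the selected key block):

k_block = torch.randn(BLOCK, BLOCK_HEADDIM, BLOCK)            # (m, d, n), already transposed

block_qk = torch.einsum('med,mdn->men', q, k_block)           # batched dot over the padded 16 rows
block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
block_qk = block_qk.sum(dim = 2) / QUERY_EXPAND_DIM           # average out the duplicated rows
qk = block_qk.permute(1, 0, 2)                                # (QUERY_HEAD_GROUPS, BLOCK, BLOCK)
qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)             # fold groups into the row dimension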
@@ -312,11 +322,18 @@ def forward_kernel(
         v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))

         p = p.to(v_block.dtype)
-        p_expanded = tl.expand_dims(p, 1)
-        p_expanded = tl.broadcast_to(p_expanded, (BLOCK, 16, BLOCK))
+        p_expanded = p.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
+        p_expanded = p_expanded.permute(1, 0, 2)
+        p_expanded = tl.expand_dims(p_expanded, 2)
+        p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
+        p_expanded = p_expanded.reshape(BLOCK, 16, BLOCK)

         block_acc_o = tl.dot(p_expanded, v_block)
-        block_acc_o = tl.sum(block_acc_o, 1) / 16.
+        block_acc_o = block_acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
+        block_acc_o = tl.sum(block_acc_o, 2) / QUERY_EXPAND_DIM
+        block_acc_o = block_acc_o.permute(1, 0, 2)
+        block_acc_o = block_acc_o.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)
+
         acc_o += block_acc_o

         # -- update statistics
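The value side mirrors the query side: the attention probabilities are duplicated across QUERY_EXPAND_DIM, contracted against the value block in one dot, and the duplicates averaged away before the group axis is folded back into the accumulator rows. Continuing the same illustrative PyTorch sketch:

p = torch.softmax(qk, dim = -1)                                # (QUERY_HEAD_GROUPS * BLOCK, BLOCK)
v_block = torch.randn(BLOCK, BLOCK, BLOCK_HEADDIM)             # stand-in for the selected value block

p_expanded = p.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK).permute(1, 0, 2)
p_expanded = p_expanded.unsqueeze(2).expand(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
p_expanded = p_expanded.reshape(BLOCK, 16, BLOCK)

block_acc_o = torch.einsum('men,mnd->med', p_expanded, v_block)
block_acc_o = block_acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
block_acc_o = block_acc_o.sum(dim = 2) / QUERY_EXPAND_DIM      # undo the duplication
block_acc_o = block_acc_o.permute(1, 0, 2).reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)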
@@ -352,7 +369,7 @@ def forward_kernel(
         out_ptrs, acc_o, mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
     )

-def flash_attn_forward(
+def native_sparse_attn_forward(
     q,
     k,
     v,
@@ -424,6 +441,7 @@ def flash_attn_forward(
         BLOCK_HEADDIM,
         BLOCK = block_size,
         QUERY_HEAD_GROUPS = head_groups,
+        QUERY_EXPAND_DIM = 16 // head_groups,
         NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
         num_warps = num_warps,
         num_stages = 1,
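On the host side the new constexpr is derived as 16 // head_groups, so QUERY_HEAD_GROUPS * QUERY_EXPAND_DIM always comes out to the padded row count of 16 used inside the kernel; this implicitly assumes head_groups divides 16. A quick sanity check of that invariant:

for head_groups in (1, 2, 4, 8, 16):
    query_expand_dim = 16 // head_groups
    assert head_groups * query_expand_dim == 16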
@@ -978,7 +996,7 @@ def backward_kernel(
             NUM_SEL_KV_BLOCKS = NUM_SEL_KV_BLOCKS
         )

-def flash_attn_backward(
+def native_sparse_attn_backward(
     do,
     q, k, v,
     kv_block_indices,
@@ -1128,7 +1146,7 @@ def forward(

         fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))

-        out, lse = flash_attn_forward(
+        out, lse = native_sparse_attn_forward(
             fq, fk, fv,
             selected_block_indices,
             fmask,
@@ -1162,7 +1180,7 @@ def backward(self, ctx, do):
         dk = torch.zeros(k.shape, dtype = torch.float32, device = device)
         dv = torch.zeros(v.shape, dtype = torch.float32, device = device)

-        flash_attn_backward(
+        native_sparse_attn_backward(
             do, q, k, v,
             sel_block_indices, mask,
             out, lse, dq, dk, dv,