@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # taken from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
 # with fixes for triton 2.3

@@ -7,6 +9,7 @@
 import torch
 from torch import Tensor
+import torch.nn.functional as F

 from einops import repeat, rearrange, reduce

@@ -22,6 +25,17 @@ def divisible_by(num, den):
 def round_up_multiple(n, mult):
     return ceil(n / mult) * mult

+def pad_at_dim(t, pad: tuple[int, int], *, dim = -1, value = 0.):
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
+def pad_to_multiple(t, mult, *, dim):
+    length = t.shape[dim]
+    padded_length = round_up_multiple(length, mult)
+    remainder = padded_length - length
+    return pad_at_dim(t, (0, remainder), dim = dim)
+
 def is_contiguous(x: Tensor):
     return x.stride(-1) == 1

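For reference, a small sketch (not part of the commit) of how the added helpers behave, assuming the pad_at_dim, pad_to_multiple and round_up_multiple definitions above are in scope:

import torch

x = torch.randn(1, 1000, 64)                # (batch, seq, dim)
padded = pad_to_multiple(x, 64, dim = -2)   # right-pads the sequence dim with zeros: 1000 -> 1024
assert padded.shape == (1, 1024, 64)
assert torch.equal(padded[:, :1000], x)     # original values untouched, only zeros appended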
@@ -168,13 +182,13 @@ def forward_kernel(
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs)
         else:
-            q = tl.load(q_ptrs, mask = offs_d[None, :] < headdim, other = 0.0)
+            q = tl.load(q_ptrs, mask = offs_d[None, None, :] < headdim, other = 0.0)
     else:
         if EVEN_HEADDIM:
-            q = tl.load(q_ptrs, mask = offs_m[:, None] < seqlen_q, other = 0.0)
+            q = tl.load(q_ptrs, mask = offs_m[None, :, None] < seqlen_q, other = 0.0)
         else:
             q = tl.load(
-                q_ptrs, mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other = 0.0
+                q_ptrs, mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim), other = 0.0
             )

     q = q.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM])
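The mask changes in this hunk (and the matching ones in the store and backward hunks below) follow from q apparently being loaded as a three-dimensional [QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM] tile before the reshape, so the sequence and head-dim bound checks each need an extra broadcast axis. A rough PyTorch analogue of that broadcasting, with made-up sizes (not part of the commit):

import torch

G, BLOCK, D = 4, 16, 64                       # head groups, rows per block, head dim (illustrative)
tile = torch.randn(G, BLOCK, D)               # stands in for the 3D q tile
offs_m, offs_d = torch.arange(BLOCK), torch.arange(D)
seqlen_q, headdim = 10, 48

row_ok = offs_m[None, :, None] < seqlen_q     # (1, BLOCK, 1): valid query rows
dim_ok = offs_d[None, None, :] < headdim      # (1, 1, D):     valid head-dim slots
masked = tile * (row_ok & dim_ok)             # broadcasts over the head-group axis to (G, BLOCK, D)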
@@ -360,13 +374,13 @@ def forward_kernel(
         if EVEN_HEADDIM:
             tl.store(out_ptrs, acc_o)
         else:
-            tl.store(out_ptrs, acc_o, mask = offs_d[None, :] < headdim)
+            tl.store(out_ptrs, acc_o, mask = offs_d[None, None, :] < headdim)
     else:
         if EVEN_HEADDIM:
-            tl.store(out_ptrs, acc_o, mask = offs_m[:, None] < seqlen_q)
+            tl.store(out_ptrs, acc_o, mask = offs_m[None, :, None] < seqlen_q)
         else:
             tl.store(
-                out_ptrs, acc_o, mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
+                out_ptrs, acc_o, mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim)
             )

 def native_sparse_attn_forward(
@@ -672,11 +686,11 @@ def backward_kernel_one_col_block(
             q = tl.load(q_ptrs)
         else:
             if EVEN_HEADDIM:
-                q = tl.load(q_ptrs, mask = offs_m[:, None] < seqlen_q, other = 0.0)
+                q = tl.load(q_ptrs, mask = offs_m[None, :, None] < seqlen_q, other = 0.0)
             else:
                 q = tl.load(
                     q_ptrs,
-                    mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                    mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
                     other = 0.0,
                 )
         # recompute p = softmax(qk, dim=-1).T
@@ -717,7 +731,7 @@ def backward_kernel_one_col_block(
             # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
             do = tl.load(
                 do_ptrs,
-                mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
                 other = 0.0,
             )

@@ -887,12 +901,12 @@ def backward_kernel_one_col_block(
             tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
         else:
             if EVEN_HEADDIM:
-                tl.atomic_add(dq_ptrs, dq, mask = offs_m[:, None] < seqlen_q, sem = 'relaxed')
+                tl.atomic_add(dq_ptrs, dq, mask = offs_m[None, :, None] < seqlen_q, sem = 'relaxed')
             else:
                 tl.atomic_add(
                     dq_ptrs,
                     dq,
-                    mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                    mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
                     sem = 'relaxed',
                 )

@@ -1248,7 +1262,7 @@ def native_sparse_attend(
     fmask,
     return_lse = False
 ):
-    assert divisible_by(fq.shape[-2], block_size)
+    seq_len = fq.shape[-2]

     out, lse = _native_sparse_attend(
         fq, fk, fv,
@@ -1260,4 +1274,4 @@ def native_sparse_attend(
     if not return_lse:
         return out

-    return out, lse
+    return out, lse[..., :seq_len]
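The dropped divisible_by assertion together with the new lse[..., :seq_len] slice suggests the wrapper now pads its inputs up to a block multiple internally (via the pad_to_multiple helper added above) and trims the padded positions on the way out. A self-contained sketch of that pad-run-trim pattern, assuming pad_to_multiple is in scope; block_kernel is a hypothetical stand-in rather than this repo's API:

import torch

def block_kernel(q, block_size):
    # hypothetical kernel that, like the triton kernels here, expects block-aligned lengths
    assert q.shape[-2] % block_size == 0
    return q, q.logsumexp(dim = -1)           # pretend attention output and lse

block_size = 64
fq = torch.randn(1, 8, 1000, 32)              # 1000 is not a multiple of 64, now allowed
seq_len = fq.shape[-2]

fq_padded = pad_to_multiple(fq, block_size, dim = -2)
out, lse = block_kernel(fq_padded, block_size)
out, lse = out[..., :seq_len, :], lse[..., :seq_len]   # trim back, mirroring the lse slice above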