# taken from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
# with fixes for triton 2.3
- # forward is modified to return unnormalized accumulation, row maxes, row lse - reduced over passed rings
- # both forwards and backwards is modified to allow for masking out the diagonal for striped ring attention

from functools import partial
import math
@@ -24,6 +22,8 @@ def round_up_multiple(n, mult):
def is_contiguous(x: Tensor):
    return x.stride(-1) == 1

+ TRITON_BLOCK_SIZE = 128 # some block size that allows triton not to break, at least half a year ago
+
INSTALL_COMMAND = 'pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly'

# make sure triton 2.1+ is installed
@@ -45,10 +45,6 @@ def is_contiguous(x: Tensor):
import triton.language as tl
from triton.language.extra import libdevice

- # constants
-
- TRITON_BLOCK_SIZE = 128
-
# kernels

@triton.heuristics(
@@ -784,8 +780,8 @@ def flash_attn_backward(
    assert all([is_contiguous(t) for t in (q, k, v, o, dq, dk, dv)])

    softmax_scale = dim ** -0.5
-     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
-     dq_accum = torch.empty_like(q, dtype = torch.float32)
+
+     dq_accum = torch.empty_like(q, dtype = torch.float32)

    # delta = torch.zeros_like(lse)

@@ -812,9 +808,6 @@ def flash_attn_backward(
        BLOCK_HEADDIM = BLOCK_HEADDIM,
    )

-     # BLOCK_M = 128
-     # BLOCK_N = 64
-     # num_warps = 4
    grid = lambda META: (
        triton.cdiv(seqlen_k, META["BLOCK"]) if META["SEQUENCE_PARALLEL"] else 1,
        batch * nheads,
@@ -887,6 +880,7 @@ def forward(
        fq, fk, fv,
        block_size,
        selected_block_indices,
+         fmask,
        num_grouped_queries
    ):
        fq, fk, fv = tuple(rearrange(t, 'b h n d -> b n h d') for t in (fq, fk, fv))
@@ -897,7 +891,7 @@ def forward(

        out, lse = flash_attn_forward(fq, fk, fv, block_size = block_size)

-         ctx.save_for_backward(fq, fk, fv, out, lse)
+         ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
        ctx._saved_variables = (block_size,)

        out = rearrange(out, 'b n h d -> b h n d')
@@ -907,7 +901,7 @@ def forward(
    def backward(self, ctx, do):
        do = rearrange(do, 'b h n d -> b n h d')

-         q, k, v, out, lse = ctx.saved_tensors
+         q, k, v, kv_indices, mask, out, lse = ctx.saved_tensors

        (
            block_size,
@@ -921,6 +915,6 @@ def backward(self, ctx, do):
        flash_attn_backward(do, q, k, v, out, lse, dq, dk, dv, block_size = block_size)

        dq, dk, dv = tuple(rearrange(t, 'b n h d -> b h n d') for t in (dq, dk, dv))
-         return dq, dk, dv, None, None, None
+         return dq, dk, dv, None, None, None, None

native_sparse_attend = NSA.apply
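
For context, the expanded ctx.save_for_backward and the extra None appended to backward's return both follow the standard torch.autograd.Function contract: every tensor the backward pass needs must be stashed via ctx.save_for_backward, and backward must return exactly one gradient (or None for non-differentiable inputs) per argument of forward, so adding fmask to forward forces a seventh return value. Below is a minimal self-contained sketch of that contract; the class name NSASketch and its trivial bodies are illustrative stand-ins, not the library's fused kernels.

import torch
from torch.autograd import Function

class NSASketch(Function):
    # illustrative stand-in for the real fused attention kernels

    @staticmethod
    def forward(ctx, fq, fk, fv, block_size, selected_block_indices, fmask, num_grouped_queries):
        out = fq.clone()  # placeholder for the fused attention forward
        # tensors needed by backward are saved here, mirroring the diff
        ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out)
        ctx._saved_variables = (block_size,)
        return out

    @staticmethod
    def backward(ctx, do):
        fq, fk, fv, selected_block_indices, fmask, out = ctx.saved_tensors
        dq, dk, dv = torch.zeros_like(fq), torch.zeros_like(fk), torch.zeros_like(fv)
        # one slot per forward argument: grads for fq, fk, fv, then None for
        # block_size, selected_block_indices, fmask, num_grouped_queries
        return dq, dk, dv, None, None, None, None

A quick usage check with dummy shapes (hypothetical sizes, just to exercise apply and backward):

fq = torch.randn(1, 4, 8, 16, requires_grad = True)
fk = torch.randn(1, 4, 8, 16, requires_grad = True)
fv = torch.randn(1, 4, 8, 16, requires_grad = True)
idx = torch.zeros(1, 4, 8, 1, dtype = torch.long)
mask = torch.ones(1, 4, 8, 1, dtype = torch.bool)
NSASketch.apply(fq, fk, fv, 64, idx, mask, 1).sum().backward()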