 # forward is modified to return unnormalized accumulation, row maxes, row lse - reduced over passed rings
 # both forward and backward are modified to allow for masking out the diagonal for striped ring attention
 
+from functools import partial
+import math
 from math import ceil
 
 import torch
@@ -82,12 +84,6 @@ def _fwd_kernel(
     CACHE_KEY_SEQLEN_Q,
     CACHE_KEY_SEQLEN_K,
     HAS_BIAS: tl.constexpr,
-    IS_CAUSAL: tl.constexpr,
-    CAUSAL_MASK_DIAGONAL: tl.constexpr,
-    LOAD_ACCUMULATED: tl.constexpr,
-    RETURN_NORMALIZED_OUTPUT: tl.constexpr,
-    SOFTCLAMP_QK_SIM: tl.constexpr,
-    SOFTCLAMP_VALUE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,
@@ -121,19 +117,13 @@ def _fwd_kernel(
 
     m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
 
-    if LOAD_ACCUMULATED:
-        m_i = tl.load(m_ptrs)
-    else:
-        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
 
     # load lse
 
     lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
 
-    if LOAD_ACCUMULATED:
-        lse_i = tl.load(lse_ptrs)
-    else:
-        lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
 
     # load accumulated output
 
@@ -146,23 +136,7 @@ def _fwd_kernel(
         + (offs_m[:, None] * stride_om + offs_d[None, :])
     )
 
-    if LOAD_ACCUMULATED:
-        if EVEN_M:
-            if EVEN_HEADDIM:
-                acc_o = tl.load(out_ptrs)
-            else:
-                acc_o = tl.load(out_ptrs, mask=offs_d[None, :] < headdim)
-        else:
-            if EVEN_HEADDIM:
-                acc_o = tl.load(out_ptrs, mask=offs_m[:, None] < seqlen_q)
-            else:
-                acc_o = tl.load(
-                    out_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
-                )
-
-        acc_o = acc_o.to(tl.float32)
-    else:
-        acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
+    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
 
     # load queries, keys, values
 
@@ -179,7 +153,7 @@ def _fwd_kernel(
             q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
         )
 
-    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
+    end_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
     for start_n in range(0, end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
 
@@ -204,21 +178,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, tl.trans(k))
 
-        if SOFTCLAMP_QK_SIM:
-            effective_softclamp_value = SOFTCLAMP_VALUE / softmax_scale
-            qk /= effective_softclamp_value
-            qk = libdevice.tanh(qk)
-            qk *= effective_softclamp_value
-
         if not EVEN_N:
             qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
 
-        if IS_CAUSAL:
-            if CAUSAL_MASK_DIAGONAL:
-                # needed for stripe attention
-                qk += tl.where(offs_m[:, None] > (start_n + offs_n)[None, :], 0, float("-inf"))
-            else:
-                qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
+        qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
 
         if HAS_BIAS:
             if EVEN_N:
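Note: with `IS_CAUSAL` and `CAUSAL_MASK_DIAGONAL` removed, the forward kernel now always applies a causal mask that keeps the diagonal (`>=`); the removed `>`-only branch was the striped ring attention case. A minimal PyTorch sketch of the two mask variants, for reference only (the helper name is illustrative, not part of the kernel):

```python
import torch

def causal_bias(seq_q, seq_k, mask_diagonal = False):
    # mask_diagonal = True reproduces the removed CAUSAL_MASK_DIAGONAL branch
    i = torch.arange(seq_q)[:, None]
    j = torch.arange(seq_k)[None, :]
    keep = (i > j) if mask_diagonal else (i >= j)
    return torch.where(keep, 0., float('-inf'))

# after this commit the kernel always behaves like mask_diagonal = False
print(causal_bias(4, 4, mask_diagonal = False))
```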
@@ -270,9 +233,8 @@ def _fwd_kernel(
         l_i_new = tl.exp(lse_i - m_ij) + l_ij
         lse_i = m_ij + tl.log(l_i_new)
 
-    if RETURN_NORMALIZED_OUTPUT:
-        acc_o_scale = tl.exp(m_i - lse_i)
-        acc_o = acc_o * acc_o_scale[:, None]
+    acc_o_scale = tl.exp(m_i - lse_i)
+    acc_o = acc_o * acc_o_scale[:, None]
 
     # offsets for m and lse
 
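For context, the rescale by `exp(m_i - lse_i)` is the final normalization step of the streaming (online) softmax, so after this change the kernel always writes normalized output and no longer needs to store the row maxes. A minimal PyTorch sketch of that bookkeeping for a single query row, with illustrative names and shapes (not the kernel's actual code):

```python
import torch

def attend_one_row(q_row, k_blocks, v_blocks, scale):
    # q_row: (d,), k_blocks / v_blocks: lists of (block_n, d) tensors
    m = torch.tensor(float('-inf'))    # running row max (m_i)
    lse = torch.tensor(float('-inf'))  # running log-sum-exp (lse_i)
    acc = 0.                           # unnormalized accumulator (acc_o)

    for k, v in zip(k_blocks, v_blocks):
        qk = (q_row @ k.t()) * scale
        m_new = torch.maximum(m, qk.max())
        p = torch.exp(qk - m_new)
        acc = acc * torch.exp(m - m_new) + p @ v                     # rescale old terms, add new block
        lse = m_new + torch.log(torch.exp(lse - m_new) + p.sum())
        m = m_new

    return acc * torch.exp(m - lse)    # matches acc_o * exp(m_i - lse_i) above

d = 8
q_row = torch.randn(d)
k_blocks = [torch.randn(4, d) for _ in range(3)]
v_blocks = [torch.randn(4, d) for _ in range(3)]
out = attend_one_row(q_row, k_blocks, v_blocks, scale = d ** -0.5)

# agrees with ordinary softmax attention over the concatenated blocks
k_all, v_all = torch.cat(k_blocks), torch.cat(v_blocks)
ref = ((q_row @ k_all.t()) * d ** -0.5).softmax(dim = -1) @ v_all
assert torch.allclose(out, ref, atol = 1e-5)
```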
@@ -283,9 +245,6 @@ def _fwd_kernel(
 
     tl.store(lse_ptrs, lse_i)
 
-    if not RETURN_NORMALIZED_OUTPUT:
-        tl.store(m_ptrs, m_i)
-
     # write to output
 
     if EVEN_M:
@@ -306,27 +265,14 @@ def flash_attn_forward(
     k,
     v,
     bias=None,
-    causal=False,
     o=None,
     m=None,
     lse=None,
     softmax_scale=None,
-    causal_mask_diagonal=False,
-    return_normalized_output=False,
-    load_accumulated=True,
-    softclamp_qk_sim=False,
-    softclamp_value=50.,
-    head_first_dim=False,
     remove_padding=False
 ):
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
 
-    if head_first_dim:
-        q, k, v = tuple(rearrange(t, 'b h n d -> b n h d') for t in (q, k, v))
-
-    if exists(o):
-        o = rearrange(o, 'b h n d -> b n h d')
-
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
 
@@ -360,17 +306,14 @@ def flash_attn_forward(
 
     if not exists(lse):
         max_neg_value = -torch.finfo(torch.float32).max
-        init_fn = partial(torch.full, fill_value=max_neg_value) if load_accumulated else torch.empty
-        lse = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+        lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
 
     if not exists(m):
         max_neg_value = -torch.finfo(torch.float32).max
-        init_fn = partial(torch.full, fill_value=max_neg_value) if load_accumulated else torch.empty
-        m = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+        m = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
 
     if not exists(o):
-        init_fn = torch.zeros_like if load_accumulated else torch.empty_like
-        o = init_fn(q)
+        o = torch.empty_like(q)
 
     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
     BLOCK = 128
@@ -407,27 +350,17 @@ def flash_attn_forward(
         seqlen_q // 32,
         seqlen_k // 32,
         has_bias,
-        causal,
-        causal_mask_diagonal,
-        load_accumulated,
-        return_normalized_output,
-        softclamp_qk_sim,
-        softclamp_value,
         BLOCK_HEADDIM,
         BLOCK_M=BLOCK,
         BLOCK_N=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
 
-    if head_first_dim:
-        o = rearrange(o, 'b n h d -> b h n d')
-
     if remove_padding:
-        m = m[..., :seqlen_q]
         lse = lse[..., :seqlen_q]
 
-    return o, m, lse
+    return o, lse
 
 @triton.jit
 def _bwd_preprocess_do_o_dot(
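With the accumulation, head-first, and causal options gone, `flash_attn_forward` now takes `(batch, seq, heads, dim)` tensors, is unconditionally causal, and returns only the output and the per-row log-sum-exp. A hedged usage sketch; the import path is an assumption, only the function itself comes from the file above:

```python
import torch
# import path is an assumption; flash_attn_forward is the function patched above
from native_sparse_attention_pytorch.triton_native_sparse_attention import flash_attn_forward

q, k, v = (torch.randn(1, 1024, 8, 64, device = 'cuda', dtype = torch.float16) for _ in range(3))

out, lse = flash_attn_forward(q, k, v, remove_padding = True)

print(out.shape)  # (1, 1024, 8, 64) - causal attention output, (batch, seq, heads, dim)
print(lse.shape)  # (1, 8, 1024)     - per-row log-sum-exp, (batch, heads, seq)
```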
@@ -533,10 +466,6 @@ def _bwd_kernel_one_col_block(
     headdim,
     ATOMIC_ADD: tl.constexpr,
     BIAS_TYPE: tl.constexpr,
-    IS_CAUSAL: tl.constexpr,
-    CAUSAL_MASK_DIAGONAL: tl.constexpr,
-    SOFTCLAMP_QK_SIM: tl.constexpr,
-    SOFTCLAMP_VALUE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,
@@ -545,7 +474,7 @@ def _bwd_kernel_one_col_block(
     BLOCK_N: tl.constexpr,
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
-    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
+    begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
     # initialize row/col offsets
     offs_qm = begin_m + tl.arange(0, BLOCK_M)
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
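The `begin_m` line above rounds the first query block down to a multiple of `BLOCK_M`: under the now always-on causal mask, query rows earlier than that can never attend to key block `start_n`, so the backward loop skips them. A tiny worked example with illustrative block sizes (not the autotuned ones):

```python
BLOCK_M, BLOCK_N = 128, 64

start_n = 3                                            # key/value block handled by this program
begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M   # rounds key offset 192 down to 128

# query rows < 128 cannot attend to keys at positions 192..255 under the causal mask,
# so the loop over query blocks starts at row 128 instead of 0
print(begin_m)  # 128
```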
@@ -627,22 +556,11 @@ def _bwd_kernel_one_col_block(
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, tl.trans(k))
 
-        if SOFTCLAMP_QK_SIM:
-            effective_softclamp_value = SOFTCLAMP_VALUE / softmax_scale
-            qk /= effective_softclamp_value
-            qk = libdevice.tanh(qk)
-            dtanh = 1. - qk * qk
-            qk *= effective_softclamp_value
-
         # Trying to combine the two masks seems to make the result wrong
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
-        if IS_CAUSAL:
-            if CAUSAL_MASK_DIAGONAL:
-                # needed for stripe attention
-                qk = tl.where(offs_m_curr[:, None] > (offs_n[None, :]), qk, float("-inf"))
-            else:
-                qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
+
+        qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
 
         if BIAS_TYPE != "none":
             tl.debug_barrier()  # Race condition otherwise
@@ -714,9 +632,6 @@ def _bwd_kernel_one_col_block(
         # for BLOCK_HEADDIM=128
         ds = (p * (dp - Di[:, None]) * softmax_scale)
 
-        if SOFTCLAMP_QK_SIM:
-            ds *= dtanh
-
         ds = ds.to(q.dtype)
 
         # compute dk = dot(ds.T, q)
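The `ds` line above is the standard softmax backward identity, dS = P * (dP - D) with D_i = sum_j dO_ij * O_ij, where D is the delta computed by `_bwd_preprocess_do_o_dot`; the kernel additionally folds in `softmax_scale` because its scores are scaled. A small PyTorch check of that identity, for reference only:

```python
import torch

torch.manual_seed(0)

s = torch.randn(4, 6, requires_grad = True)   # pre-softmax scores (scale already folded in here)
v = torch.randn(6, 8)

p = s.softmax(dim = -1)
o = p @ v
do = torch.randn_like(o)
o.backward(do)

dp = do @ v.t()
Di = (do * o).sum(dim = -1, keepdim = True)   # same Delta the preprocess kernel computes
ds = p * (dp - Di)

assert torch.allclose(ds, s.grad, atol = 1e-5)
```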
@@ -823,7 +738,7 @@ def init_to_zero(name):
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
     ],
-    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
+    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "BLOCK_HEADDIM"],
 )
 @triton.heuristics(
     {
@@ -877,10 +792,6 @@ def _bwd_kernel(
     CACHE_KEY_SEQLEN_Q,
     CACHE_KEY_SEQLEN_K,
     BIAS_TYPE: tl.constexpr,
-    IS_CAUSAL: tl.constexpr,
-    CAUSAL_MASK_DIAGONAL: tl.constexpr,
-    SOFTCLAMP_QK_SIM: tl.constexpr,
-    SOFTCLAMP_VALUE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
     EVEN_M: tl.constexpr,
@@ -934,10 +845,6 @@ def _bwd_kernel(
                 headdim,
                 ATOMIC_ADD=False,
                 BIAS_TYPE=BIAS_TYPE,
-                IS_CAUSAL=IS_CAUSAL,
-                CAUSAL_MASK_DIAGONAL=CAUSAL_MASK_DIAGONAL,
-                SOFTCLAMP_QK_SIM=SOFTCLAMP_QK_SIM,
-                SOFTCLAMP_VALUE=SOFTCLAMP_VALUE,
                 BLOCK_HEADDIM=BLOCK_HEADDIM,
                 EVEN_M=EVEN_M,
                 EVEN_N=EVEN_N,
@@ -973,10 +880,6 @@ def _bwd_kernel(
             headdim,
             ATOMIC_ADD=True,
             BIAS_TYPE=BIAS_TYPE,
-            IS_CAUSAL=IS_CAUSAL,
-            CAUSAL_MASK_DIAGONAL=CAUSAL_MASK_DIAGONAL,
-            SOFTCLAMP_QK_SIM=SOFTCLAMP_QK_SIM,
-            SOFTCLAMP_VALUE=SOFTCLAMP_VALUE,
             BLOCK_HEADDIM=BLOCK_HEADDIM,
             EVEN_M=EVEN_M,
             EVEN_N=EVEN_N,
@@ -997,15 +900,12 @@ def flash_attn_backward(
     dv,
     delta=None,
     bias=None,
-    causal=False,
-    causal_mask_diagonal=False,
     softmax_scale=None,
-    softclamp_qk_sim=False,
-    softclamp_value=50.
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
         do = do.contiguous()
+
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
     # assert d in {16, 32, 64, 128}
@@ -1113,10 +1013,6 @@ def flash_attn_backward(
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
         bias_type,
-        causal,
-        causal_mask_diagonal,
-        softclamp_qk_sim,
-        softclamp_value,
         BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
@@ -1142,10 +1038,33 @@ def forward(
         selected_block_indices,
         num_grouped_queries
     ):
-        raise NotImplementedError
+        fq, fk, fv = tuple(rearrange(t, 'b h n d -> b n h d') for t in (fq, fk, fv))
+
+        dtype = fq.dtype
+
+        fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))
+
+        out, lse = flash_attn_forward(fq, fk, fv)
+
+        ctx.save_for_backward(fq, fk, fv, out, lse)
+
+        out = rearrange(out, 'b n h d -> b h n d')
+        return out.type(dtype)
 
     @classmethod
     def backward(self, ctx, do):
-        raise NotImplementedError
+        do = rearrange(do, 'b h n d -> b n h d')
+
+        q, k, v, out, lse = ctx.saved_tensors
+
+        do = do.half()
+        dq = torch.zeros_like(q)
+        dk = torch.zeros_like(k)
+        dv = torch.zeros_like(v)
+
+        flash_attn_backward(do, q, k, v, out, lse, dq, dk, dv)
+
+        dq, dk, dv = tuple(rearrange(t, 'b n h d -> b h n d') for t in (dq, dk, dv))
+        return dq, dk, dv, None, None, None
 
 native_sparse_attend = NSA.apply
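For completeness, a hedged sketch of how the new autograd function might be exercised. In this commit `block_size`, `selected_block_indices`, and `num_grouped_queries` are accepted but not yet used by the kernels, so the call behaves like dense causal flash attention; the import path and the index-tensor shape below are assumptions:

```python
import torch
# import path is an assumption; native_sparse_attend = NSA.apply from the file above
from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend

b, h, n, d = 1, 8, 1024, 64
block_size = 64
num_grouped_queries = 1

q, k, v = (torch.randn(b, h, n, d, device = 'cuda', dtype = torch.float16, requires_grad = True) for _ in range(3))

# hypothetical shape: one selected key/value block index per query position (currently ignored)
selected_block_indices = torch.zeros(b, h, n, 1, device = 'cuda', dtype = torch.long)

out = native_sparse_attend(q, k, v, block_size, selected_block_indices, num_grouped_queries)
out.sum().backward()   # runs the Triton backward to fill dq, dk, dv

print(out.shape)       # (1, 8, 1024, 64) - same (batch, heads, seq, dim) layout as the inputs
```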