@@ -56,7 +56,6 @@ def _fwd_kernel(
     Q,
     K,
     V,
-    Bias,
     Out,
     M,
     Lse,
@@ -70,9 +69,6 @@ def _fwd_kernel(
     stride_vb,
     stride_vh,
     stride_vn,
-    stride_bb,
-    stride_bh,
-    stride_bm,
     stride_ob,
     stride_oh,
     stride_om,
@@ -83,7 +79,6 @@ def _fwd_kernel(
     headdim,
     CACHE_KEY_SEQLEN_Q,
     CACHE_KEY_SEQLEN_K,
-    HAS_BIAS: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,
@@ -110,9 +105,6 @@ def _fwd_kernel(
         V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
     )
 
-    if HAS_BIAS:
-        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
-
     # maximum
 
     m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
@@ -183,22 +175,8 @@ def _fwd_kernel(
 
         qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
 
-        if HAS_BIAS:
-            if EVEN_N:
-                bias = tl.load(b_ptrs + start_n)
-            else:
-                bias = tl.load(
-                    b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
-                )
-            bias = bias[None, :]
-
-            bias = bias.to(tl.float32)
-            qk = qk * softmax_scale + bias
-            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
-            p = tl.exp(qk - m_ij[:, None])
-        else:
-            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
-            p = tl.exp(qk * softmax_scale - m_ij[:, None])
+        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
+        p = tl.exp(qk * softmax_scale - m_ij[:, None])
 
         l_ij = tl.sum(p, 1)
 
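Aside (illustration, not code from this commit): with the bias path gone, the softmax scale is folded straight into the exponent, so each key block updates the running softmax statistics in one step. A rough PyTorch sketch of that per-block update, reusing the names from the hunk above; the m_i/acc_o bookkeeping and the final normalization are assumptions that mirror the standard flash-attention recurrence, only the two marked lines correspond to the added code:

    import torch

    def fwd_block_update(qk, v, m_i, lse_i, acc_o, softmax_scale):
        # qk: (BLOCK_M, BLOCK_N) raw q @ k^T scores for one key block, v: (BLOCK_N, d)
        # m_i, lse_i: (BLOCK_M,) running reference max and log-sum-exp, acc_o: (BLOCK_M, d)
        m_ij = torch.maximum(qk.max(dim=1).values * softmax_scale, lse_i)  # as in the added lines
        p = torch.exp(qk * softmax_scale - m_ij[:, None])                  # as in the added lines
        l_ij = p.sum(dim=1)
        # rescale the running output to the new reference and fold in this block (assumed)
        acc_o = acc_o * torch.exp(m_i - m_ij)[:, None] + p @ v
        lse_i = m_ij + torch.log(torch.exp(lse_i - m_ij) + l_ij)
        return m_ij, lse_i, acc_o

    # start from m_i = lse_i = -inf and acc_o = 0; after the last block the normalized
    # output would be acc_o * torch.exp(m_i - lse_i)[:, None] (assumed, outside this hunk)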
@@ -264,7 +242,6 @@ def flash_attn_forward(
     q,
     k,
     v,
-    bias=None,
     o=None,
     m=None,
     lse=None,
@@ -285,23 +262,6 @@ def flash_attn_forward(
 
     softmax_scale = default(softmax_scale, d ** -0.5)
 
-    has_bias = exists(bias)
-
-    if has_bias:
-        assert bias.dtype in [q.dtype, torch.float]
-        assert bias.is_cuda
-
-        if bias.ndim == 2:
-            bias = repeat(bias, 'b j -> b h i j', h=nheads, i=seqlen_q)
-
-        if not is_contiguous(bias):
-            bias = bias.contiguous()
-
-        assert bias.shape[-2:] == (seqlen_q, seqlen_k)
-        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
-
-    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
     seqlen_q_rounded = ceil(seqlen_q / 128) * 128
 
     if not exists(lse):
@@ -324,7 +284,6 @@ def flash_attn_forward(
         q,
         k,
         v,
-        bias,
         o,
         m,
         lse,
@@ -338,7 +297,6 @@ def flash_attn_forward(
         v.stride(0),
         v.stride(2),
         v.stride(1),
-        *bias_strides,
         o.stride(0),
         o.stride(2),
         o.stride(1),
@@ -349,7 +307,6 @@ def flash_attn_forward(
         d,
         seqlen_q // 32,
         seqlen_k // 32,
-        has_bias,
         BLOCK_HEADDIM,
         BLOCK_M=BLOCK,
         BLOCK_N=BLOCK,
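For comparison (illustration only, not code from this repo): after this change flash_attn_forward computes plain scaled-dot-product attention with no additive bias. The hunks shown apply the causal mask unconditionally, and the .stride(0)/.stride(2)/.stride(1) ordering passed to the kernel suggests a (batch, seqlen, heads, dim) layout; a hedged reference implementation under those assumptions:

    import torch

    def reference_attention(q, k, v, softmax_scale=None):
        # q, k, v: (batch, seqlen, heads, dim) -- layout inferred from the stride
        # ordering in the kernel launch above; causal masking assumed throughout
        scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** -0.5
        sim = torch.einsum('b i h d, b j h d -> b h i j', q, k) * scale
        i, j = sim.shape[-2:]
        causal_mask = torch.ones(i, j, device=q.device).triu(1).bool()
        sim = sim.masked_fill(causal_mask, float('-inf'))
        attn = sim.softmax(dim=-1)
        return torch.einsum('b h i j, b j h d -> b i h d', attn, v)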
@@ -445,7 +402,6 @@ def _bwd_kernel_one_col_block(
     Q,
     K,
     V,
-    Bias,
     DO,
     DQ,
     DK,
@@ -456,7 +412,6 @@ def _bwd_kernel_one_col_block(
     stride_qm,
     stride_kn,
     stride_vn,
-    stride_bm,
     stride_dom,
     stride_dqm,
     stride_dkn,
@@ -465,7 +420,6 @@ def _bwd_kernel_one_col_block(
     seqlen_k,
     headdim,
     ATOMIC_ADD: tl.constexpr,
-    BIAS_TYPE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     EVEN_M: tl.constexpr,
     EVEN_N: tl.constexpr,
@@ -486,10 +440,7 @@ def _bwd_kernel_one_col_block(
     v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
     do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
     dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
-    if BIAS_TYPE == "vector":
-        b_ptrs = Bias + offs_n
-    elif BIAS_TYPE == "matrix":
-        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
+
     # initialize dv and dk
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
@@ -562,33 +513,14 @@ def _bwd_kernel_one_col_block(
 
         qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
 
-        if BIAS_TYPE != "none":
-            tl.debug_barrier()  # Race condition otherwise
-            if BIAS_TYPE == "vector":
-                if EVEN_N:
-                    bias = tl.load(b_ptrs).to(tl.float32)
-                else:
-                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
-                bias = bias[None, :]
-            elif BIAS_TYPE == "matrix":
-                if EVEN_M & EVEN_N:
-                    bias = tl.load(b_ptrs).to(tl.float32)
-                else:
-                    bias = tl.load(
-                        b_ptrs,
-                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
-                        other=0.0,
-                    ).to(tl.float32)
-            qk = qk * softmax_scale + bias
         # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
         # Also wrong for headdim=64.
         if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
         lse_i = tl.load(LSE + offs_m_curr)
-        if BIAS_TYPE == "none":
-            p = tl.exp(qk * softmax_scale - lse_i[:, None])
-        else:
-            p = tl.exp(qk - lse_i[:, None])
+
+        p = tl.exp(qk * softmax_scale - lse_i[:, None])
+
         # compute dv
         # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
         # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
@@ -693,8 +625,7 @@ def _bwd_kernel_one_col_block(
         dq_ptrs += BLOCK_M * stride_dqm
         q_ptrs += BLOCK_M * stride_qm
         do_ptrs += BLOCK_M * stride_dom
-        if BIAS_TYPE == "matrix":
-            b_ptrs += BLOCK_M * stride_bm
+
     # write-back
     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
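Aside (illustration, not code from this commit): because the forward pass stores a per-row log-sum-exp, the backward kernel never materializes the softmax; with the bias branch removed, each tile's probabilities are recomputed exactly as the new line above does, p = exp(qk * softmax_scale - lse). A small PyTorch sketch of that recomputation and the dv accumulation it feeds; the dv formula is an assumption that mirrors the standard flash-attention backward, not a line in this diff:

    import torch

    def recompute_probs(q_blk, k_blk, lse_blk, offs_m, offs_n, softmax_scale):
        # q_blk: (BLOCK_M, d), k_blk: (BLOCK_N, d)
        # lse_blk: (BLOCK_M,) log-sum-exp rows saved by the forward pass
        # offs_m, offs_n: absolute query/key indices of this tile, for the causal mask
        qk = q_blk @ k_blk.t()
        qk = torch.where(offs_m[:, None] >= offs_n[None, :], qk, torch.full_like(qk, float('-inf')))
        return torch.exp(qk * softmax_scale - lse_blk[:, None])

    # per key block, dv then accumulates p^T @ dO over all query tiles (assumed):
    #   dv += recompute_probs(q_blk, k_blk, lse_blk, offs_m, offs_n, scale).t() @ do_blk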
@@ -738,7 +669,7 @@ def init_to_zero(name):
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
     ],
-    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "BLOCK_HEADDIM"],
+    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BLOCK_HEADDIM"],
 )
 @triton.heuristics(
     {
@@ -752,7 +683,6 @@ def _bwd_kernel(
     Q,
     K,
     V,
-    Bias,
     DO,
     DQ,
     DK,
@@ -769,9 +699,6 @@ def _bwd_kernel(
     stride_vb,
     stride_vh,
     stride_vn,
-    stride_bb,
-    stride_bh,
-    stride_bm,
     stride_dob,
     stride_doh,
     stride_dom,
@@ -791,7 +718,6 @@ def _bwd_kernel(
     headdim,
     CACHE_KEY_SEQLEN_Q,
     CACHE_KEY_SEQLEN_K,
-    BIAS_TYPE: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
     EVEN_M: tl.constexpr,
@@ -811,8 +737,6 @@ def _bwd_kernel(
     DQ += off_b * stride_dqb + off_h * stride_dqh
     DK += off_b * stride_dkb + off_h * stride_dkh
     DV += off_b * stride_dvb + off_h * stride_dvh
-    if BIAS_TYPE != "none":
-        Bias += off_b * stride_bb + off_h * stride_bh
     # pointer to row-wise quantities in value-like data
     D += off_hb * seqlen_q_rounded
     LSE += off_hb * seqlen_q_rounded
@@ -824,7 +748,6 @@ def _bwd_kernel(
                 Q,
                 K,
                 V,
-                Bias,
                 DO,
                 DQ,
                 DK,
@@ -835,7 +758,6 @@ def _bwd_kernel(
                 stride_qm,
                 stride_kn,
                 stride_vn,
-                stride_bm,
                 stride_dom,
                 stride_dqm,
                 stride_dkn,
@@ -844,7 +766,6 @@ def _bwd_kernel(
                 seqlen_k,
                 headdim,
                 ATOMIC_ADD=False,
-                BIAS_TYPE=BIAS_TYPE,
                 BLOCK_HEADDIM=BLOCK_HEADDIM,
                 EVEN_M=EVEN_M,
                 EVEN_N=EVEN_N,
@@ -859,7 +780,6 @@ def _bwd_kernel(
             Q,
             K,
             V,
-            Bias,
             DO,
             DQ,
             DK,
@@ -870,7 +790,6 @@ def _bwd_kernel(
             stride_qm,
             stride_kn,
             stride_vn,
-            stride_bm,
             stride_dom,
             stride_dqm,
             stride_dkn,
@@ -879,7 +798,6 @@ def _bwd_kernel(
             seqlen_k,
             headdim,
             ATOMIC_ADD=True,
-            BIAS_TYPE=BIAS_TYPE,
             BLOCK_HEADDIM=BLOCK_HEADDIM,
             EVEN_M=EVEN_M,
             EVEN_N=EVEN_N,
@@ -899,7 +817,6 @@ def flash_attn_backward(
     dk,
     dv,
     delta=None,
-    bias=None,
     softmax_scale=None,
 ):
     # Make sure that the last dimension is contiguous
@@ -944,24 +861,6 @@ def flash_attn_backward(
         BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
 
-    has_bias = bias is not None
-    bias_type = "none"
-    if has_bias:
-        assert bias.dtype in [q.dtype, torch.float]
-        assert bias.is_cuda
-        assert bias.dim() == 4
-        assert bias.stride(-1) == 1
-        if bias.shape[2:] == (1, seqlen_k):
-            bias_type = "vector"
-        elif bias.shape[2:] == (seqlen_q, seqlen_k):
-            bias_type = "matrix"
-        else:
-            raise RuntimeError(
-                "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
-            )
-        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
-    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
     # BLOCK_M = 128
     # BLOCK_N = 64
     # num_warps = 4
@@ -973,7 +872,6 @@ def flash_attn_backward(
         q,
         k,
         v,
-        bias,
        do,
         dq_accum,
         dk,
@@ -990,7 +888,6 @@ def flash_attn_backward(
         v.stride(0),
         v.stride(2),
         v.stride(1),
-        *bias_strides,
         do.stride(0),
         do.stride(2),
         do.stride(1),
@@ -1012,7 +909,6 @@ def flash_attn_backward(
         seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        bias_type,
         BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,