@@ -111,13 +111,13 @@ def _fwd_kernel(
 
     m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
 
-    m_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")
+    m_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")
 
     # lse
 
     lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
 
-    lse_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")
+    lse_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")
 
     # output
 
@@ -130,7 +130,7 @@ def _fwd_kernel(
         + (offs_m[:, None] * stride_om + offs_d[None, :])
     )
 
-    acc_o = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
+    acc_o = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
 
     # load queries, keys, values
 
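For orientation, the accumulators set up in the hunks above (`m_i` for the running row max, `lse_i` for the running log-sum-exp, `acc_o` for the un-normalized output) drive the usual online-softmax recurrence of flash attention. The sketch below is not the kernel itself; it is a plain PyTorch restatement of that recurrence, with the function name, 2D tensor shapes, and block loop chosen here purely for illustration.

```python
import torch

def online_softmax_attention(q, k, v, block = 128, scale = None):
    # q, k, v: (seq, dim) tensors; a reference for the m_i / lse_i / acc_o updates
    n, d = q.shape
    scale = d ** -0.5 if scale is None else scale

    m_i   = torch.full((n,), float('-inf'))   # running row max
    lse_i = torch.full((n,), float('-inf'))   # running log-sum-exp
    acc_o = torch.zeros(n, d)                 # un-normalized output accumulator

    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        qk = (q @ kb.t()) * scale                       # scores for this key block
        m_new = torch.maximum(m_i, qk.amax(dim = -1))   # new running max
        p = torch.exp(qk - m_new[:, None])
        alpha = torch.exp(m_i - m_new)                  # rescale the old accumulator
        acc_o = acc_o * alpha[:, None] + p @ vb
        lse_i = m_new + torch.log(torch.exp(lse_i - m_new) + p.sum(dim = -1))
        m_i = m_new

    out = acc_o / torch.exp(lse_i - m_i)[:, None]       # final normalization
    return out, lse_i

# quick check against dense softmax attention
q, k, v = (torch.randn(256, 64) for _ in range(3))
out, lse = online_softmax_attention(q, k, v)
ref = torch.softmax((q @ k.t()) * 64 ** -0.5, dim = -1) @ v
assert torch.allclose(out, ref, atol = 1e-5)
```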
@@ -243,6 +243,8 @@ def flash_attn_forward(
     q,
     k,
     v,
+    indices,
+    mask,
     block_size = 128
 ):
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
@@ -328,15 +330,20 @@ def _bwd_preprocess_do_o_dot(
     off_hb = tl.program_id(1)
     off_b = off_hb // nheads
     off_h = off_hb % nheads
+
     # initialize offsets
+
     offs_m = start_m * BLOCK + tl.arange(0, BLOCK)
     offs_d = tl.arange(0, BLOCK_HEADDIM)
+
     # load
+
     o = tl.load(
         Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
         mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
         other = 0.0,
     ).to(tl.float32)
+
     do = tl.load(
         DO
         + off_b * stride_dob
@@ -346,8 +353,11 @@ def _bwd_preprocess_do_o_dot(
         mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
         other = 0.0,
     ).to(tl.float32)
+
     delta = tl.sum(o * do, axis = 1)
+
     # write-back
+
     tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
 
 @triton.jit
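The `_bwd_preprocess_do_o_dot` hunks above compute, per query row, the dot product of the attention output with its incoming gradient (`delta = tl.sum(o * do, axis = 1)`), which the softmax backward reuses as the `D` / `Di` term. In plain PyTorch this is just a row-wise reduction over the head dimension; a minimal sketch, assuming the `(batch, seqlen, heads, dim)` layout used by the wrapper and with the function name made up here:

```python
import torch

def bwd_preprocess_do_o_dot(o, do):
    # delta[b, n, h] = sum_d o[b, n, h, d] * do[b, n, h, d]
    # computed in fp32, mirroring the .to(tl.float32) casts in the kernel
    return (o.float() * do.float()).sum(dim = -1)

o  = torch.randn(2, 1024, 8, 64, dtype = torch.float16)
do = torch.randn_like(o)
delta = bwd_preprocess_do_o_dot(o, do)   # shape (2, 1024, 8)
```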
@@ -538,22 +548,31 @@ def _bwd_kernel_one_col_block(
         # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
         if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
+
         dp = tl.dot(do, tl.trans(v))
+
         # There's a race condition for headdim=48
         if not EVEN_HEADDIM:
             tl.debug_barrier()
+
         # compute ds = p * (dp - delta[:, None])
         # Putting the subtraction after the dp matmul (instead of before) is slightly faster
+
         Di = tl.load(D + offs_m)
+
         # Converting ds to q.dtype here reduces register pressure and makes it much faster
         # for BLOCK_HEADDIM=128
+
         ds = (p * (dp - Di[:, None]) * softmax_scale)
 
         ds = ds.to(q.dtype)
 
         # compute dk = dot(ds.T, q)
+
         dk += tl.dot(tl.trans(ds), q)
+
         # compute dq
+
         if not (
             EVEN_M & EVEN_HEADDIM
         ):  # Otherewise there's a race condition when BIAS_TYPE='matrix'
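The hunk above is the per-tile softmax backward: `dp = do @ v^T`, `ds = p * (dp - delta[:, None]) * softmax_scale`, then `dk += ds^T @ q` (the `dq` update follows in the elided lines). A single-tile PyTorch restatement of that algebra, with the function name and 2D tile shapes assumed for illustration:

```python
import torch

def softmax_bwd_tile(q, k, v, p, do, delta, softmax_scale):
    # one (query block, key block) tile of the backward pass
    # p: post-softmax probabilities for the tile, delta: row-wise sum(o * do)
    dp = do @ v.t()                                   # dp = dot(do, v^T)
    ds = p * (dp - delta[:, None]) * softmax_scale    # ds = p * (dp - delta) * scale
    dk = ds.t() @ q                                   # contribution to dk
    dq = ds @ k                                       # contribution to dq
    dv = p.t() @ do                                   # dv term (handled earlier in the kernel)
    return dq, dk, dv
```

This covers only the arithmetic for one tile; the kernel accumulates these contributions across blocks, which is what the race-condition barriers and the ATOMIC_ADD note above are guarding.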
@@ -613,6 +632,7 @@ def _bwd_kernel_one_col_block(
         # do_ptrs += BLOCK * stride_dom
 
     # write-back
+
     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
     _bwd_store_dk_dv(
@@ -756,14 +776,12 @@ def _bwd_kernel(
 
 def flash_attn_backward(
     do,
-    q,
-    k,
-    v,
+    q, k, v,
+    indices,
+    mask,
     o,
     lse,
-    dq,
-    dk,
-    dv,
+    dq, dk, dv,
     block_size = 128
 ):
     # Make sure that the last dimension is contiguous
@@ -805,7 +823,7 @@ def flash_attn_backward(
         seqlen_q_rounded,
         dim,
         BLOCK = block_size,
-        BLOCK_HEADDIM = BLOCK_HEADDIM,
+        BLOCK_HEADDIM = BLOCK_HEADDIM,
     )
 
     grid = lambda META: (
@@ -889,7 +907,12 @@ def forward(
 
         fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))
 
-        out, lse = flash_attn_forward(fq, fk, fv, block_size = block_size)
+        out, lse = flash_attn_forward(
+            fq, fk, fv,
+            selected_block_indices,
+            fmask,
+            block_size = block_size
+        )
 
         ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
         ctx._saved_variables = (block_size,)
@@ -901,7 +924,7 @@ def forward(
     def backward(self, ctx, do):
         do = rearrange(do, 'b h n d -> b n h d')
 
-        q, k, v, kv_indices, mask, out, lse = ctx.saved_tensors
+        q, k, v, sel_block_indices, mask, out, lse = ctx.saved_tensors
 
         (
             block_size,
@@ -912,7 +935,12 @@ def backward(self, ctx, do):
         dk = torch.zeros_like(k)
         dv = torch.zeros_like(v)
 
-        flash_attn_backward(do, q, k, v, out, lse, dq, dk, dv, block_size = block_size)
+        flash_attn_backward(
+            do, q, k, v,
+            sel_block_indices, mask,
+            out, lse, dq, dk, dv,
+            block_size = block_size
+        )
 
         dq, dk, dv = tuple(rearrange(t, 'b n h d -> b h n d') for t in (dq, dk, dv))
         return dq, dk, dv, None, None, None, None