@@ -473,18 +473,18 @@ def backward_store_dk_dv(
     # if we just call tl.store(dv_ptrs), there's a race condition
     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
-            tl.store(dv_ptrs, dv)
-            tl.store(dk_ptrs, dk)
+            tl.atomic_add(dv_ptrs, dv, sem='relaxed')
+            tl.atomic_add(dk_ptrs, dk, sem='relaxed')
         else:
-            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
-            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
+            tl.atomic_add(dv_ptrs, dv, mask=offs_d[None, :] < headdim, sem='relaxed')
+            tl.atomic_add(dk_ptrs, dk, mask=offs_d[None, :] < headdim, sem='relaxed')
     else:
         if EVEN_HEADDIM:
-            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
-            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
+            tl.atomic_add(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k, sem='relaxed')
+            tl.atomic_add(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k, sem='relaxed')
         else:
-            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
-            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
+            tl.atomic_add(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), sem='relaxed')
+            tl.atomic_add(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), sem='relaxed')
 
 
 @triton.jit
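Swapping `tl.store` for `tl.atomic_add` is what removes the race mentioned in the comment: once several programs can each contribute a partial dK/dV tile for the same key/value rows, a plain store lets the last writer win. Below is a minimal, self-contained sketch (not from this repo; the kernel and helper names are made up) of the same pattern with Triton's relaxed atomics, where every program adds its partial result into one shared output tile:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def partial_sum_kernel(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
    # each program reduces one chunk of x, then accumulates it into the
    # same BLOCK-sized output tile that every other program also writes
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs, mask=offs < N, other=0.0)
    # tl.store here would race across programs; a relaxed atomic add is
    # enough because no program reads what another program wrote
    tl.atomic_add(out_ptr + tl.arange(0, BLOCK), x, sem='relaxed')


def chunked_accumulate(x, BLOCK=128):
    # the accumulator must start at zero, since programs only ever add into it
    out = torch.zeros(BLOCK, device=x.device, dtype=torch.float32)
    grid = (triton.cdiv(x.numel(), BLOCK),)
    partial_sum_kernel[grid](x, out, x.numel(), BLOCK=BLOCK)
    return out


# x = torch.randn(4096, device='cuda')
# torch.testing.assert_close(chunked_accumulate(x), x.view(-1, 128).sum(0))
```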
@@ -751,6 +751,14 @@ def backward_kernel_one_col_block(
         V + blocks_offs_n[:, :, None] * stride_vn + offs_d[None, None, :]
     )
 
+    block_dv_ptrs = (
+        DV + blocks_offs_n[:, :, None] * stride_dvn + offs_d[None, None, :]
+    )
+
+    block_dk_ptrs = (
+        DK + blocks_offs_n[:, :, None] * stride_dkn + offs_d[None, None, :]
+    )
+
     block_k = tl.load(block_k_ptrs)
     block_v = tl.load(block_v_ptrs)
 
@@ -765,6 +773,11 @@ def backward_kernel_one_col_block(
 
     p = tl.exp(qk * softmax_scale - lse_i[:, None])
 
+    block_dv = p.to(do.dtype)[:, :, None] * do[:, None, :]
+    block_dv = tl.where(block_masks[:, None, None], block_dv, 0.)
+
+    tl.atomic_add(block_dv_ptrs, block_dv, sem='relaxed')
+
     # # increment pointers
     # dq_ptrs += BLOCK * stride_dqm
     # q_ptrs += BLOCK * stride_qm
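The quantity being accumulated into `DV` here is the standard attention gradient: with P = softmax(QKᵀ · softmax_scale) and out = P V, the gradient w.r.t. V is dV = Pᵀ dO. The broadcasted product `p[:, :, None] * do[:, None, :]` holds the per-(query, key) contributions, and the atomic adds sum them into the key-block rows of DV. A hedged single-block reference in plain PyTorch (names are illustrative, not taken from the kernel):

```python
import torch

def reference_dv(q, k, do, softmax_scale):
    # one (query block, key block) tile, no masking: P = softmax(Q K^T * scale)
    p = torch.softmax(q @ k.transpose(-2, -1) * softmax_scale, dim=-1)
    # out = P @ V, so the gradient w.r.t. V is P^T @ dO
    return p.transpose(-2, -1) @ do

# q: (seqlen_q, d), k: (seqlen_k, d), do: (seqlen_q, d)
# dv = reference_dv(q, k, do, d ** -0.5)   # -> (seqlen_k, d)
```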
@@ -965,6 +978,8 @@ def flash_attn_backward(
     softmax_scale = dim ** -0.5
 
     dq_accum = torch.empty_like(q, dtype=torch.float32)
+    dk_accum = torch.empty_like(k, dtype=torch.float32)
+    dv_accum = torch.empty_like(v, dtype=torch.float32)
 
     # delta = torch.zeros_like(lse)
 
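Like `dq_accum`, the new dK/dV accumulators are kept in float32 even when k and v are half precision: the backward pass adds many small partial tiles into the same buffer, and repeated additions into a float16 accumulator start rounding away once the running value grows. A small illustration of that effect (not tied to this repo):

```python
import torch

# accumulate 10,000 increments of 1e-4 in fp16 vs fp32
acc16 = torch.tensor(0.0, dtype=torch.float16)
acc32 = torch.tensor(0.0, dtype=torch.float32)
step = 1e-4

for _ in range(10_000):
    acc16 += torch.tensor(step, dtype=torch.float16)
    acc32 += step

print(acc16)  # stalls around 0.25: 1e-4 is below half an fp16 ulp at that magnitude
print(acc32)  # ~1.0, as expected
```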
@@ -995,6 +1010,7 @@ def flash_attn_backward(
         triton.cdiv(seqlen_k, META["BLOCK"]) if META["SEQUENCE_PARALLEL"] else 1,
         batch * nheads,
     )
+
     backward_kernel[grid](
         q,
         k,
@@ -1003,8 +1019,8 @@ def flash_attn_backward(
         kv_block_mask,
         do,
         dq_accum,
-        dk,
-        dv,
+        dk_accum,
+        dv_accum,
         lse,
         delta,
         softmax_scale,
@@ -1052,7 +1068,10 @@ def flash_attn_backward(
         # num_warps=num_warps,
         # num_stages=1,
     )
+
     dq.copy_(dq_accum)
+    dk.copy_(dk_accum)
+    dv.copy_(dv_accum)
 
     return delta
 
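Taken together, the host-side changes follow one pattern: allocate float32 scratch buffers for dQ/dK/dV, let the (possibly sequence-parallel) kernel atomically accumulate partial gradients into them, then copy the results back into the caller's gradient tensors in their original dtype. A condensed sketch of that flow, with a stand-in callable in place of the `backward_kernel[grid](...)` launch, so the call and shapes are illustrative only:

```python
import torch

def backward_accumulate_pattern(q, k, v, do, dq, dk, dv, launch_kernel):
    # float32 scratch buffers; this sketch zero-initializes them because the
    # kernel only ever atomically adds partial tiles into these buffers
    dq_accum = torch.zeros_like(q, dtype=torch.float32)
    dk_accum = torch.zeros_like(k, dtype=torch.float32)
    dv_accum = torch.zeros_like(v, dtype=torch.float32)

    # stand-in for backward_kernel[grid](...): accumulates into the *_accum buffers
    launch_kernel(q, k, v, do, dq_accum, dk_accum, dv_accum)

    # cast back into the caller-provided gradient tensors (e.g. fp16/bf16)
    dq.copy_(dq_accum)
    dk.copy_(dk_accum)
    dv.copy_(dv_accum)
```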