@@ -677,20 +677,9 @@ def backward_kernel_one_col_block(
677677 ): # Otherewise there's a race condition when BIAS_TYPE='matrix'
678678 tl .debug_barrier ()
679679
680- dq = tl .dot ( ds , k )
680+ dq = tl .zeros ([ BLOCK , BLOCK_HEADDIM ], dtype = tl . float32 )
681681
682- if EVEN_M & EVEN_HEADDIM : # Race condition if we just do EVEN_M
683- tl .atomic_add (dq_ptrs , dq , sem = 'relaxed' )
684- else :
685- if EVEN_HEADDIM :
686- tl .atomic_add (dq_ptrs , dq , mask = offs_m [:, None ] < seqlen_q , sem = 'relaxed' )
687- else :
688- tl .atomic_add (
689- dq_ptrs ,
690- dq ,
691- mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ),
692- sem = 'relaxed' ,
693- )
682+ dq += tl .dot (ds , k )
694683
695684 # handle kv block indices using atomic adds for starters, todo: swap dq and dk/dv loops at some point, semi big refactor
696685
@@ -774,7 +763,22 @@ def backward_kernel_one_col_block(
774763 block_dq = tl .dot (ds_expanded , block_k )
775764 block_dq = tl .sum (block_dq , 1 ) / 16
776765
777- tl .atomic_add (dq_ptrs , block_dq , sem = 'relaxed' )
766+ dq += block_dq
767+
768+ # update dq
769+
770+ if EVEN_M & EVEN_HEADDIM : # Race condition if we just do EVEN_M
771+ tl .atomic_add (dq_ptrs , dq , sem = 'relaxed' )
772+ else :
773+ if EVEN_HEADDIM :
774+ tl .atomic_add (dq_ptrs , dq , mask = offs_m [:, None ] < seqlen_q , sem = 'relaxed' )
775+ else :
776+ tl .atomic_add (
777+ dq_ptrs ,
778+ dq ,
779+ mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ),
780+ sem = 'relaxed' ,
781+ )
778782
779783 # # increment pointers
780784 # dq_ptrs += BLOCK * stride_dqm
0 commit comments