@@ -676,54 +676,21 @@ def backward_kernel_one_col_block(
676676 EVEN_M & EVEN_HEADDIM
677677 ): # Otherwise there's a race condition when BIAS_TYPE='matrix'
678678 tl .debug_barrier ()
679- if not ATOMIC_ADD :
680- if EVEN_M & EVEN_HEADDIM : # Race condition if we just do EVEN_M
681- dq = tl .load (dq_ptrs , eviction_policy = "evict_last" )
682- dq += tl .dot (ds , k )
683- tl .store (dq_ptrs , dq , eviction_policy = "evict_last" )
684- else :
685- if EVEN_HEADDIM :
686- dq = tl .load (
687- dq_ptrs ,
688- mask = offs_m [:, None ] < seqlen_q ,
689- other = 0.0 ,
690- eviction_policy = "evict_last" ,
691- )
692- dq += tl .dot (ds , k )
693- tl .store (
694- dq_ptrs ,
695- dq ,
696- mask = offs_m [:, None ] < seqlen_q ,
697- eviction_policy = "evict_last" ,
698- )
699- else :
700- dq = tl .load (
701- dq_ptrs ,
702- mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ),
703- other = 0.0 ,
704- eviction_policy = "evict_last" ,
705- )
706- dq += tl .dot (ds , k )
707- tl .store (
708- dq_ptrs ,
709- dq ,
710- mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ),
711- eviction_policy = "evict_last" ,
712- )
713- else : # If we're parallelizing across the seqlen_k dimension
714- dq = tl .dot (ds , k )
715- if EVEN_M & EVEN_HEADDIM : # Race condition if we just do EVEN_M
716- tl .atomic_add (dq_ptrs , dq , sem = 'relaxed' )
679+
680+ dq = tl .dot (ds , k )
681+
682+ if EVEN_M & EVEN_HEADDIM : # Race condition if we just do EVEN_M
683+ tl .atomic_add (dq_ptrs , dq , sem = 'relaxed' )
684+ else :
685+ if EVEN_HEADDIM :
686+ tl .atomic_add (dq_ptrs , dq , mask = offs_m [:, None ] < seqlen_q , sem = 'relaxed' )
717687 else :
718- if EVEN_HEADDIM :
719- tl .atomic_add (dq_ptrs , dq , mask = offs_m [:, None ] < seqlen_q , sem = 'relaxed' )
720- else :
721- tl .atomic_add (
722- dq_ptrs ,
723- dq ,
724- mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ),
725- sem = 'relaxed' ,
726- )
688+ tl .atomic_add (
689+ dq_ptrs ,
690+ dq ,
691+ mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ),
692+ sem = 'relaxed' ,
693+ )
727694
728695 # handle kv block indices using atomic adds for starters, todo: swap dq and dk/dv loops at some point, semi big refactor
729696
@@ -765,8 +732,8 @@ def backward_kernel_one_col_block(
765732 q_expanded = tl .expand_dims (q , 1 )
766733 q_expanded = tl .broadcast_to (q_expanded , (BLOCK , 16 , BLOCK_HEADDIM ))
767734
768- block_k = tl .permute (block_k , (0 , 2 , 1 ))
769- block_qk = tl .dot (q_expanded , block_k )
735+ block_k_permuted = tl .permute (block_k , (0 , 2 , 1 ))
736+ block_qk = tl .dot (q_expanded , block_k_permuted )
770737
771738 qk = tl .sum (block_qk , 1 ) / 16.
772739 qk += tl .where (block_masks [:, None ], 0 , float ("-inf" ))
@@ -800,6 +767,15 @@ def backward_kernel_one_col_block(
800767
801768 tl .atomic_add (block_dk_ptrs , block_dk , sem = 'relaxed' )
802769
770+ # block dq
771+
772+ ds_expanded = tl .expand_dims (ds , 1 )
773+ ds_expanded = tl .broadcast_to (ds_expanded , (BLOCK , 16 , BLOCK ))
774+ block_dq = tl .dot (ds_expanded , block_k )
775+ block_dq = tl .sum (block_dq , 1 ) / 16
776+
777+ tl .atomic_add (dq_ptrs , block_dq , sem = 'relaxed' )
778+
803779 # # increment pointers
804780 # dq_ptrs += BLOCK * stride_dqm
805781 # q_ptrs += BLOCK * stride_qm
0 commit comments