@@ -587,7 +587,7 @@ def backward_kernel_one_col_block_sparse(
     QUERY_EXPAND_DIM: tl.constexpr,
     OFF_SEL_KV_BLOCKS: tl.constexpr
 ):
-    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
+    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)

     begin_m = ((start_n * BLOCK) // BLOCK) * BLOCK

@@ -609,73 +609,32 @@ def backward_kernel_one_col_block_sparse(

     q_ptrs = (
         Q +
-        offs_g[:, None, None] * stride_qh +
-        offs_qm[None, :, None] * stride_qm +
+        offs_g[None, :, None] * stride_qh +
+        offs_qm[:, None, None] * stride_qm +
         offs_d[None, None, :]
     )

     do_ptrs = (
         DO +
-        offs_g[:, None, None] * stride_doh +
-        offs_qm[None, :, None] * stride_dom +
+        offs_g[None, :, None] * stride_doh +
+        offs_qm[:, None, None] * stride_dom +
         offs_d[None, None, :]
     )

     dq_ptrs = (
         DQ +
-        offs_g[:, None, None] * stride_dqh +
-        offs_qm[None, :, None] * stride_dqm +
+        offs_g[None, :, None] * stride_dqh +
+        offs_qm[:, None, None] * stride_dqm +
         offs_d[None, None, :]
     )

-    # initialize dv and dk
-
-    dv = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
-    dk = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
-
     # There seems to be some problem with Triton pipelining that makes results wrong for
     # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
     # may have zero step, and pipelining with the bias matrix could screw it up.
     # So we just exit early.

     if begin_m >= seqlen_q:
-        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
-        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
-        backward_store_dk_dv(
-            dk_ptrs,
-            dv_ptrs,
-            dk,
-            dv,
-            offs_n,
-            offs_d,
-            seqlen_k,
-            headdim,
-            EVEN_M = EVEN_M,
-            EVEN_N = EVEN_N,
-            EVEN_HEADDIM = EVEN_HEADDIM,
-        )
         return
-    # k and v stay in SRAM throughout
-    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
-    # if we just call tl.load(k_ptrs), we get the wrong output!
-    if EVEN_N & EVEN_M:
-        if EVEN_HEADDIM:
-            k = tl.load(k_ptrs)
-            v = tl.load(v_ptrs)
-        else:
-            k = tl.load(k_ptrs, mask = offs_d[None, :] < headdim, other = 0.0)
-            v = tl.load(v_ptrs, mask = offs_d[None, :] < headdim, other = 0.0)
-    else:
-        if EVEN_HEADDIM:
-            k = tl.load(k_ptrs, mask = offs_n[:, None] < seqlen_k, other = 0.0)
-            v = tl.load(v_ptrs, mask = offs_n[:, None] < seqlen_k, other = 0.0)
-        else:
-            k = tl.load(
-                k_ptrs, mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other = 0.0
-            )
-            v = tl.load(
-                v_ptrs, mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other = 0.0
-            )

     # same block for block causal diagonal

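The re-ordered pointer arithmetic above just swaps which axis leads: query rows now index axis 0 and the grouped query heads axis 1, so per-row masks such as `offs_m[:, None, None] < seqlen_q` broadcast without any transpose. A rough NumPy sketch of that broadcasting, with toy sizes and placeholder strides rather than the kernel's real values:

```python
import numpy as np

BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM = 4, 2, 8          # toy sizes, not the kernel's
stride_qh = BLOCK_HEADDIM                                  # placeholder strides
stride_qm = QUERY_HEAD_GROUPS * BLOCK_HEADDIM

offs_qm = np.arange(BLOCK)               # query rows handled by this program
offs_g  = np.arange(QUERY_HEAD_GROUPS)   # query heads sharing one kv head
offs_d  = np.arange(BLOCK_HEADDIM)       # feature dimension

q_offsets = (
    offs_g[None, :, None] * stride_qh +
    offs_qm[:, None, None] * stride_qm +
    offs_d[None, None, :]
)

# rows lead, so a per-row bound broadcasts directly against the pointer block
assert q_offsets.shape == (BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
row_mask = offs_qm[:, None, None] < 3    # stand-in for offs_m[:, None, None] < seqlen_q
```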
@@ -685,28 +644,17 @@ def backward_kernel_one_col_block_sparse(
         q = tl.load(q_ptrs)
     else:
         if EVEN_HEADDIM:
-            q = tl.load(q_ptrs, mask = offs_m[None, :, None] < seqlen_q, other = 0.0)
+            q = tl.load(
+                q_ptrs,
+                mask = offs_m[:, None, None] < seqlen_q,
+                other = 0.0
+            )
         else:
             q = tl.load(
                 q_ptrs,
-                mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
-                other = 0.0,
+                mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
+                other = 0.0,
             )
-    # recompute p = softmax(qk, dim=-1).T
-
-    q = q.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM])
-
-    qk = tl.dot(q, tl.trans(k))
-
-    qk = qk.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
-
-    # Trying to combine the two masks seem to make the result wrong
-    if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
-        qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
-
-    qk = tl.where(offs_m[:, None] >= (offs_n[None, :]), qk, float("-inf"))
-
-    qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)

     # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
     # Also wrong for headdim=64.
@@ -715,9 +663,7 @@ def backward_kernel_one_col_block_sparse(
         tl.debug_barrier()

     lse_i = tl.load(LSE + offs_d_or_lse)
-    lse_i = lse_i.reshape(QUERY_HEAD_GROUPS * BLOCK)
-
-    p = tl.exp(qk * softmax_scale - lse_i[:, None])
+    lse_i = tl.trans(lse_i)  # (m, h)

     # compute dv
     # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
@@ -730,12 +676,10 @@ def backward_kernel_one_col_block_sparse(
     # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
     do = tl.load(
         do_ptrs,
-        mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
-        other = 0.0,
+        mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
+        other = 0.0,
     )

-    do = do.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)
-
     # compute dp = dot(v, do)
     # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
     # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
@@ -752,12 +696,12 @@ def backward_kernel_one_col_block_sparse(
     # Putting the subtraction after the dp matmul (instead of before) is slightly faster

     Di = tl.load(D + offs_d_or_lse)
-    Di = Di.reshape(QUERY_HEAD_GROUPS * BLOCK)
+    Di = tl.trans(Di)  # (m, h)

     # Converting ds to q.dtype here reduces register pressure and makes it much faster
     # for BLOCK_HEADDIM=128

-    dq = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
+    dq = tl.zeros([BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM], dtype = tl.float32)

     # handle kv block indices using atomic adds for starters, todo: swap dq and dk/dv loops at some point, semi big refactor

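The "atomic adds for starters" note above refers to the `tl.atomic_add` calls on `block_dk_ptrs` / `block_dv_ptrs` in the following hunks: different query rows and programs can select the same key/value block, so their partial dK/dV contributions have to be summed into global memory in no particular order rather than kept in a per-program accumulator. A minimal host-side analogue of that accumulation, with made-up toy shapes and `np.add.at` standing in for `tl.atomic_add`:

```python
import numpy as np

NUM_KV_BLOCKS, BLOCK, BLOCK_HEADDIM = 6, 4, 8             # toy sizes
dv = np.zeros((NUM_KV_BLOCKS * BLOCK, BLOCK_HEADDIM))     # global dV buffer

# two contributions that happen to target the same selected kv block (index 2)
selected_blocks = [2, 2]
partials = [np.random.randn(BLOCK, BLOCK_HEADDIM) for _ in selected_blocks]

for blk, part in zip(selected_blocks, partials):
    rows = blk * BLOCK + np.arange(BLOCK)   # plays the role of blocks_offs_n
    np.add.at(dv, rows, part)               # unordered accumulation, like tl.atomic_add(..., sem='relaxed')
```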
@@ -774,30 +718,39 @@ def backward_kernel_one_col_block_sparse(
     block_indices = tl.load(kv_block_indices_ptrs + OFF_SEL_KV_BLOCKS)
     block_masks = tl.load(kv_block_mask_ptrs + OFF_SEL_KV_BLOCKS)

-    blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]
+    blocks_offs_n = (
+        block_indices[:, None] * BLOCK +
+        tl.arange(0, BLOCK)[None, :]
+    )

     block_k_ptrs = (
-        K + blocks_offs_n[:, :, None] * stride_kn + offs_d[None, None, :]
+        K +
+        blocks_offs_n[:, :, None] * stride_kn +
+        offs_d[None, None, :]
     )

     block_v_ptrs = (
-        V + blocks_offs_n[:, :, None] * stride_vn + offs_d[None, None, :]
+        V +
+        blocks_offs_n[:, :, None] * stride_vn +
+        offs_d[None, None, :]
     )

     block_dv_ptrs = (
-        DV + blocks_offs_n[:, :, None] * stride_dvn + offs_d[None, None, :]
+        DV +
+        blocks_offs_n[:, :, None] * stride_dvn +
+        offs_d[None, None, :]
     )

     block_dk_ptrs = (
-        DK + blocks_offs_n[:, :, None] * stride_dkn + offs_d[None, None, :]
+        DK +
+        blocks_offs_n[:, :, None] * stride_dkn +
+        offs_d[None, None, :]
     )

     block_k = tl.load(block_k_ptrs)
     block_v = tl.load(block_v_ptrs)

-    q_expanded = q.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
-    q_expanded = q_expanded.permute(1, 0, 2)
-    q_expanded = tl.expand_dims(q_expanded, 2)
+    q_expanded = tl.expand_dims(q, 2)
     q_expanded = tl.broadcast_to(q_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
     q_expanded = q_expanded.reshape(BLOCK, 16, BLOCK_HEADDIM)

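The expand-and-broadcast above appears to exist so that the middle dimension fed to `tl.dot` reaches Triton's minimum tile size of 16 even when there are fewer than 16 grouped query heads; the duplicated copies are collapsed again right after the dot by summing over `QUERY_EXPAND_DIM` and dividing. A minimal NumPy sketch of that round trip, assuming `QUERY_HEAD_GROUPS * QUERY_EXPAND_DIM == 16` (here 4 × 4) and toy block sizes:

```python
import numpy as np

BLOCK, BLOCK_HEADDIM = 4, 8                          # toy sizes
QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM = 4, 4           # assumed to multiply to 16

q       = np.random.randn(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
block_k = np.random.randn(BLOCK, BLOCK, BLOCK_HEADDIM)   # one selected kv block per query row

# duplicate each head group QUERY_EXPAND_DIM times so the dot's middle dim is 16
q_expanded = np.broadcast_to(
    q[:, :, None, :],
    (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM),
).reshape(BLOCK, 16, BLOCK_HEADDIM)

block_qk = np.einsum('mid,mnd->min', q_expanded, block_k)   # batched dot, like tl.dot
block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)

# the duplicates are identical, so summing and dividing recovers one copy per group
qk = block_qk.sum(axis = 2) / QUERY_EXPAND_DIM

assert np.allclose(qk, np.einsum('mgd,mnd->mgn', q, block_k))
```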
@@ -806,84 +759,77 @@ def backward_kernel_one_col_block_sparse(

     block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
     qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM
-    qk = qk.permute(1, 0, 2)
-
-    qk += tl.where(block_masks[None, :, None], 0, float("-inf"))

-    qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)
+    qk += tl.where(block_masks[:, None, None], 0, float("-inf"))

-    p = tl.exp(qk * softmax_scale - lse_i[:, None])
+    p = tl.exp(qk * softmax_scale - lse_i[:, :, None])

     # take care of block dv

-    block_dv = p.to(do.dtype)[:, :, None] * do[:, None, :]
+    block_dv = p.to(do.dtype)[:, :, :, None] * do[:, :, None, :]

-    block_dv = block_dv.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)
-    block_dv = tl.sum(block_dv, 0)
+    # block_dv = block_dv.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)
+    block_dv = tl.sum(block_dv, 1)

     tl.atomic_add(block_dv_ptrs, block_dv, sem = 'relaxed')

     # get dp

-    do_expanded = do.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
-    do_expanded = do_expanded.permute(1, 0, 2)
-    do_expanded = tl.expand_dims(do_expanded, 2)
-    do_expanded = tl.broadcast_to(do_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
-    do_expanded = do_expanded.reshape(BLOCK, 16, BLOCK_HEADDIM)
+    do = tl.expand_dims(do, 2)
+    do = tl.broadcast_to(do, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
+    do = do.reshape(BLOCK, 16, BLOCK_HEADDIM)

     block_v = tl.permute(block_v, (0, 2, 1))

-    dp = tl.dot(do_expanded, block_v)
+    dp = tl.dot(do, block_v)

     dp = dp.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
     dp = tl.sum(dp, 2) / QUERY_EXPAND_DIM
-    dp = dp.permute(1, 0, 2)
-    dp = dp.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)

     # ds

-    ds = (p * (dp - Di[:, None]) * softmax_scale)
+    ds = (p * (dp - Di[:, :, None]) * softmax_scale)
     ds = ds.to(q.dtype)

     # block dk

-    block_dk = ds[:, :, None] * q[:, None, :]
-    block_dk = block_dk.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK, BLOCK_HEADDIM)
-    block_dk = tl.sum(block_dk, 0)
+    block_dk = ds[:, :, :, None] * q[:, :, None, :]
+    block_dk = tl.sum(block_dk, 1)

     tl.atomic_add(block_dk_ptrs, block_dk, sem = 'relaxed')

     # block dq

-    ds_expanded = ds.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
-    ds_expanded = ds_expanded.permute(1, 0, 2)
-    ds_expanded = tl.expand_dims(ds_expanded, 2)
+    ds_expanded = tl.expand_dims(ds, 2)
     ds_expanded = tl.broadcast_to(ds_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
     ds_expanded = ds_expanded.reshape(BLOCK, 16, BLOCK)

     block_dq = tl.dot(ds_expanded, block_k)

     block_dq = block_dq.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
     block_dq = tl.sum(block_dq, 2) / QUERY_EXPAND_DIM
-    block_dq = block_dq.permute(1, 0, 2)
-    block_dq = block_dq.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)

     dq += block_dq

     # update dq

-    dq = dq.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
+    dq = dq.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)

     if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
         tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
     else:
         if EVEN_HEADDIM:
-            tl.atomic_add(dq_ptrs, dq, mask = offs_m[None, :, None] < seqlen_q, sem = 'relaxed')
+            tl.atomic_add(
+                dq_ptrs,
+                dq,
+                mask = offs_m[:, None, None] < seqlen_q,
+                sem = 'relaxed'
+            )
         else:
             tl.atomic_add(
                 dq_ptrs,
                 dq,
-                mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
+                mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
                 sem = 'relaxed',
             )

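Putting the pieces together, the per-(row, head) softmax statistics `lse_i` and `Di` are now kept in the same (m, h) layout as the rest of the kernel and broadcast over the kv axis with `[:, :, None]`. A short NumPy restatement of that softmax re-materialisation and its backward step, with toy shapes and random placeholder values:

```python
import numpy as np

BLOCK, QUERY_HEAD_GROUPS, KV_BLOCK = 4, 2, 4       # toy sizes
softmax_scale = 0.125

qk    = np.random.randn(BLOCK, QUERY_HEAD_GROUPS, KV_BLOCK)   # (m, h, n) after the dot
lse_i = np.random.randn(BLOCK, QUERY_HEAD_GROUPS)             # log-sum-exp, transposed to (m, h)
Di    = np.random.randn(BLOCK, QUERY_HEAD_GROUPS)             # rowwise sum(dO * O), also (m, h)
dp    = np.random.randn(BLOCK, QUERY_HEAD_GROUPS, KV_BLOCK)

# re-materialise the attention probabilities and backprop through the softmax,
# broadcasting the (m, h) statistics over the kv axis as the kernel's indexing does
p  = np.exp(qk * softmax_scale - lse_i[:, :, None])
ds = p * (dp - Di[:, :, None]) * softmax_scale
```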