@@ -411,8 +411,6 @@ def _bwd_kernel_one_col_block(
411411 # Initialize dv and dk
412412 dv = tl .zeros ([BLOCK_N , BLOCK_HEADDIM ], dtype = tl .float32 )
413413 dk = tl .zeros ([BLOCK_N , BLOCK_HEADDIM ], dtype = tl .float32 )
414- # Initialize softmax unscale factor
415- softmax_unscale = 1.0 / softmax_scale
416414 # There seems to be some problem with Triton pipelining that makes results wrong for
417415 # headdim=64, seqlen=(113, 255). In this case the for loop may have zero step,
418416 # and pipelining with the bias matrix could screw it up. So we just exit early.
@@ -549,41 +547,38 @@ def _bwd_kernel_one_col_block(
549547 # Putting the subtraction after the dp matmul (instead of before) is slightly faster
550548 Di = tl .load (D + offs_m_curr )
551549
552- # Compute dbias
553- dbias = (p * (dp - Di [:, None ])).to (q .dtype )
554-
550+ # Compute ds
551+ # Converting ds to q.dtype here reduces register pressure and makes it much faster
552+ # for BLOCK_HEADDIM=128
553+ ds = (p * (dp - Di [:, None ])).to (q .dtype )
554+
555555 # Write back
556556 if not (EVEN_M & EVEN_N ):
557557 tl .debug_barrier ()
558558 if HAS_BIAS :
559559 if ACCUM_DBIAS :
560- acc_dbias += tl .sum (dbias , axis = 0 )
560+ acc_dbias += tl .sum (ds , axis = 0 )
561561 else :
562562 if EVEN_M & EVEN_N :
563563 tl .store (
564564 db_ptrs ,
565- dbias ,
565+ ds ,
566566 )
567567 else :
568568 tl .store (
569569 db_ptrs ,
570- dbias ,
570+ ds ,
571571 mask = (offs_m_curr [:, None ] < seqlen_q ) & (offs_n [None , :] < seqlen_k ),
572572 )
573573
574- # Compute ds
575- # Converting ds to q.dtype here reduces register pressure and makes it much faster
576- # for BLOCK_HEADDIM=128
577- ds = (dbias * softmax_scale ).to (q .dtype )
578-
579574 # Compute dk
580575 dk += tl .dot (tl .trans (ds ), q )
581576
582577 # Compute dq
583578 if not ATOMIC_ADD :
584579 if EVEN_M & EVEN_HEADDIM : # Race condition if we just do EVEN_M
585580 dq = tl .load (dq_ptrs , eviction_policy = "evict_last" )
586- dq += tl .dot (ds , ( k * softmax_unscale ).to (ds .dtype ) )
581+ dq += tl .dot (ds , k ).to (ds .dtype )
587582 tl .store (dq_ptrs , dq , eviction_policy = "evict_last" )
588583 else :
589584 if EVEN_HEADDIM :
@@ -593,7 +588,7 @@ def _bwd_kernel_one_col_block(
593588 other = 0.0 ,
594589 eviction_policy = "evict_last" ,
595590 )
596- dq += tl .dot (ds , ( k * softmax_unscale ).to (ds .dtype ) )
591+ dq += tl .dot (ds , k ).to (ds .dtype )
597592 tl .store (
598593 dq_ptrs ,
599594 dq ,
@@ -607,15 +602,15 @@ def _bwd_kernel_one_col_block(
607602 other = 0.0 ,
608603 eviction_policy = "evict_last" ,
609604 )
610- dq += tl .dot (ds , ( k * softmax_unscale ).to (ds .dtype ) )
605+ dq += tl .dot (ds , k ).to (ds .dtype )
611606 tl .store (
612607 dq_ptrs ,
613608 dq ,
614609 mask = (offs_m_curr [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ),
615610 eviction_policy = "evict_last" ,
616611 )
617612 else : # If we're parallelizing across the seqlen_k dimension
618- dq = tl .dot (ds , ( k * softmax_unscale ).to (ds .dtype ) )
613+ dq = tl .dot (ds , k ).to (ds .dtype )
619614 if EVEN_M & EVEN_HEADDIM : # Race condition if we just do EVEN_M
620615 tl .atomic_add (dq_ptrs , dq )
621616 else :
@@ -638,7 +633,10 @@ def _bwd_kernel_one_col_block(
638633 m_ptrs += BLOCK_M * stride_mm
639634 if HAS_BIAS :
640635 b_ptrs += BLOCK_M * stride_bm
641-
636+
637+ # Scale dk
638+ dk = (dk * softmax_scale ).to (dk .dtype )
639+
642640 # Write back
643641 dv_ptrs = DV + (offs_n [:, None ] * stride_dvn + offs_d [None , :])
644642 dk_ptrs = DK + (offs_n [:, None ] * stride_dkn + offs_d [None , :])
0 commit comments