
Commit 53c34fa

Rebalances backward scaling
Moves the softmax scaling out of the per-step ds computation and applies it once in the final dk update, cutting register pressure and simplifying the accumulation. dq now accumulates the unscaled ds against k directly, removing the per-step k * softmax_unscale multiply for more stable gradients.
1 parent cd60a83 commit 53c34fa
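
The rebalancing relies on the scale factor commuting with the matmuls, so it can be applied once after the loop instead of every iteration. Below is a minimal standalone sketch (not part of the repository; p, dp, Di, q, k, and softmax_scale are hypothetical stand-ins for the kernel's per-block tensors, with arbitrary block sizes) checking that deferring the softmax scale to a single dk update, and dropping the per-step k * softmax_unscale multiply from dq, reproduces the old formulation:

```python
# Sketch only: verifies the algebraic equivalence of the old and new scaling
# placement on dense fp32 tensors. Names and shapes are illustrative, not the
# kernel's actual layout.
import torch

torch.manual_seed(0)
BLOCK_M, BLOCK_N, HEADDIM = 16, 16, 32
softmax_scale = 0.125

p = torch.rand(BLOCK_M, BLOCK_N)          # attention probabilities for one block
dp = torch.randn(BLOCK_M, BLOCK_N)        # dO @ V^T for the same block
Di = torch.randn(BLOCK_M)                 # row-wise rescaling term
q = torch.randn(BLOCK_M, HEADDIM)
k = torch.randn(BLOCK_N, HEADDIM)

# Old formulation: scale ds per step, unscale k when accumulating dq.
dbias_old = p * (dp - Di[:, None])
ds_old = dbias_old * softmax_scale
dk_old = ds_old.T @ q
dq_old = ds_old @ (k * (1.0 / softmax_scale))

# New formulation: keep ds unscaled, use k directly for dq,
# and apply softmax_scale to dk once after accumulation.
ds_new = p * (dp - Di[:, None])
dk_new = (ds_new.T @ q) * softmax_scale
dq_new = ds_new @ k

print(torch.allclose(dk_old, dk_new, atol=1e-5))  # True
print(torch.allclose(dq_old, dq_new, atol=1e-5))  # True
```

Since both paths agree, deferring the scale saves one multiply per iteration and drops the softmax_unscale register inside the loop.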

File tree

1 file changed (+16, -18)
flash_dmattn/flash_dmattn_triton.py

Lines changed: 16 additions & 18 deletions
@@ -411,8 +411,6 @@ def _bwd_kernel_one_col_block(
    # Initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
-   # Initialize softmax unscale factor
-   softmax_unscale = 1.0 / softmax_scale
    # There seems to be some problem with Triton pipelining that makes results wrong for
    # headdim=64, seqlen=(113, 255). In this case the for loop may have zero step,
    # and pipelining with the bias matrix could screw it up. So we just exit early.
@@ -549,41 +547,38 @@ def _bwd_kernel_one_col_block(
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        Di = tl.load(D + offs_m_curr)

-       # Compute dbias
-       dbias = (p * (dp - Di[:, None])).to(q.dtype)
-
+       # Compute ds
+       # Converting ds to q.dtype here reduces register pressure and makes it much faster
+       # for BLOCK_HEADDIM=128
+       ds = (p * (dp - Di[:, None])).to(q.dtype)
+
        # Write back
        if not (EVEN_M & EVEN_N):
            tl.debug_barrier()
        if HAS_BIAS:
            if ACCUM_DBIAS:
-               acc_dbias += tl.sum(dbias, axis=0)
+               acc_dbias += tl.sum(ds, axis=0)
            else:
                if EVEN_M & EVEN_N:
                    tl.store(
                        db_ptrs,
-                       dbias,
+                       ds,
                    )
                else:
                    tl.store(
                        db_ptrs,
-                       dbias,
+                       ds,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
                    )

-       # Compute ds
-       # Converting ds to q.dtype here reduces register pressure and makes it much faster
-       # for BLOCK_HEADDIM=128
-       ds = (dbias * softmax_scale).to(q.dtype)
-
        # Compute dk
        dk += tl.dot(tl.trans(ds), q)

        # Compute dq
        if not ATOMIC_ADD:
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
-               dq += tl.dot(ds, (k * softmax_unscale).to(ds.dtype))
+               dq += tl.dot(ds, k).to(ds.dtype)
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
@@ -593,7 +588,7 @@ def _bwd_kernel_one_col_block(
                        other=0.0,
                        eviction_policy="evict_last",
                    )
-                   dq += tl.dot(ds, (k * softmax_unscale).to(ds.dtype))
+                   dq += tl.dot(ds, k).to(ds.dtype)
                    tl.store(
                        dq_ptrs,
                        dq,
@@ -607,15 +602,15 @@ def _bwd_kernel_one_col_block(
                        other=0.0,
                        eviction_policy="evict_last",
                    )
-                   dq += tl.dot(ds, (k * softmax_unscale).to(ds.dtype))
+                   dq += tl.dot(ds, k).to(ds.dtype)
                    tl.store(
                        dq_ptrs,
                        dq,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        eviction_policy="evict_last",
                    )
        else:  # If we're parallelizing across the seqlen_k dimension
-           dq = tl.dot(ds, (k * softmax_unscale).to(ds.dtype))
+           dq = tl.dot(ds, k).to(ds.dtype)
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                tl.atomic_add(dq_ptrs, dq)
            else:
@@ -638,7 +633,10 @@ def _bwd_kernel_one_col_block(
        m_ptrs += BLOCK_M * stride_mm
        if HAS_BIAS:
            b_ptrs += BLOCK_M * stride_bm
-
+
+   # Scale dk
+   dk = (dk * softmax_scale).to(dk.dtype)
+
    # Write back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
