Skip to content

Commit 502a1a4

Browse files
committed
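Enhances FlashDMAttnFunc: forward now records the original size of the bias tensor's last dimension (seqlen_k_bias_og, or 0 when bias is None) on ctx, and backward sums dbias over that last dimension (keeping it as size 1) when the recorded size is 1; otherwise dbias is only sliced to k.shape[1] as before.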
1 parent 8effe3c commit 502a1a4

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

flash_dmattn/flash_dmattn_interface.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -409,6 +409,7 @@ def forward(
409409
deterministic = False
410410
if return_softmax is None:
411411
return_softmax = False
412+
seqlen_k_bias_og = bias.shape[-1] if bias is not None else 0
412413

413414
# Padding to multiple of 8 for 16-bit memory allocations
414415
head_size_og = q.size(3)
@@ -446,6 +447,7 @@ def forward(
446447
ctx.is_causal = is_causal
447448
ctx.softcap = softcap
448449
ctx.deterministic = deterministic
450+
ctx.seqlen_k_bias_og = seqlen_k_bias_og
449451

450452
out = out_padded[..., :head_size_og]
451453

@@ -491,7 +493,7 @@ def backward(
491493
dv = dv[..., : dout.shape[-1]]
492494

493495
if dbias is not None:
494-
dbias = dbias[..., : k.shape[1]]
496+
dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : k.shape[1]]
495497

496498
return dq, dk, dv, None, dbias, None, None, None, None, None, None
497499

0 commit comments

Comments
 (0)