Commit 755e0a4

Aligns mask/bias to 128 and supports broadcast
Updates attention mask/bias handling to round the key length to a multiple of 128 and expand length-1 tensors or pad as needed, preventing shape mismatches and reducing unnecessary padding of K/V. Simplifies backward by slicing only the bias gradient to the original key length and removing tracking of the original sequence length. Clarifies docs to allow broadcastable dimensions for mask/bias across batch, heads, and sequence.
1 parent 2a8f9ea commit 755e0a4
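
Below is a minimal standalone sketch of the alignment behaviour described in the commit message. It assumes round_multiple rounds an integer up to the nearest multiple; the align_mask_or_bias wrapper is hypothetical and only mirrors the forward-pass logic in the diff that follows.

import torch
from typing import Optional


def round_multiple(x: int, m: int) -> int:
    # Assumed behaviour of the helper used in the diff: round x up to a multiple of m.
    return (x + m - 1) // m * m


def align_mask_or_bias(t: Optional[torch.Tensor], seqlen_k: int, multiple: int = 128) -> Optional[torch.Tensor]:
    # Hypothetical wrapper mirroring the forward-pass handling of attn_mask/attn_bias:
    # expand a length-1 key dimension, otherwise pad it up to the rounded key length.
    if t is None:
        return None
    seqlen_k_rounded = round_multiple(seqlen_k, multiple)
    if t.shape[-1] == seqlen_k_rounded:
        return t
    if t.shape[-1] == 1:
        # Broadcastable key dimension: expand instead of materialising padding.
        return t.expand(*t.shape[:-1], seqlen_k_rounded)
    # Full-width tensor: pad the key dimension to the rounded length.
    return torch.nn.functional.pad(t, [0, seqlen_k_rounded - t.shape[-1]])


# A length-1 bias is expanded; a full-width boolean mask is padded (with False).
bias = torch.zeros(2, 4, 16, 1)
mask = torch.ones(2, 4, 16, 100, dtype=torch.bool)
print(align_mask_or_bias(bias, seqlen_k=100).shape)  # torch.Size([2, 4, 16, 128])
print(align_mask_or_bias(mask, seqlen_k=100).shape)  # torch.Size([2, 4, 16, 128])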

flash_dmattn/flash_dmattn_interface.py

Lines changed: 15 additions & 16 deletions
@@ -416,14 +416,17 @@ def forward(
             q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
-        seqlen_k_og = k.shape[1]
-        if seqlen_k_og % 8 != 0:
-            k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
-            v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
-            if mask is not None:
-                mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False)
-            if bias is not None:
-                bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0)
+        seqlen_k_rounded = round_multiple(k.shape[1], 128)
+        if mask is not None and mask.shape[-1] != seqlen_k_rounded:
+            if mask.shape[-1] == 1:
+                mask = mask.expand(*mask.shape[:-1], seqlen_k_rounded)
+            else:
+                mask = torch.nn.functional.pad(mask, [0, seqlen_k_rounded - mask.shape[-1]])
+        if bias is not None and bias.shape[-1] != seqlen_k_rounded:
+            if bias.shape[-1] == 1:
+                bias = bias.expand(*bias.shape[:-1], seqlen_k_rounded)
+            else:
+                bias = torch.nn.functional.pad(bias, [0, seqlen_k_rounded - bias.shape[-1]])
 
         out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
             q,
@@ -443,7 +446,6 @@ def forward(
         ctx.is_causal = is_causal
         ctx.softcap = softcap
         ctx.deterministic = deterministic
-        ctx.seqlen_k_og = seqlen_k_og
 
         out = out_padded[..., :head_size_og]

@@ -488,11 +490,8 @@ def backward(
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
 
-        if ctx.seqlen_k_og % 8 != 0:
-            dk = dk[:, : ctx.seqlen_k_og, :, :]
-            dv = dv[:, : ctx.seqlen_k_og, :, :]
-            if dbias is not None:
-                dbias = dbias[..., : ctx.seqlen_k_og]
+        if dbias is not None:
+            dbias = dbias[..., : k.shape[1]]
 
         return dq, dk, dv, None, dbias, None, None, None, None, None, None

@@ -646,10 +645,10 @@ def flash_dmattn_func(
         key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim)
         value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim)
         attn_mask: torch.Tensor, optional. The attention mask boolean tensor of
-            shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|0}, seqlen_k) to apply to the attention scores.
+            shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to apply to the attention scores.
             If None, no mask is applied.
         attn_bias: torch.Tensor, optional. The attention bias float tensor of
-            shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|0}, seqlen_k) to add to the attention scores.
+            shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to add to the attention scores.
             If None, no bias is applied.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
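
For context, here is a hypothetical usage sketch of the relaxed mask/bias shapes documented above. It assumes flash_dmattn_func can be imported from flash_dmattn.flash_dmattn_interface, accepts the query/key/value and attn_mask/attn_bias arguments shown in the docstring, and runs on a CUDA device with the extension built; shapes and dtypes are illustrative only.

import torch
from flash_dmattn.flash_dmattn_interface import flash_dmattn_func  # assumed import path

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 128, 200, 8, 64
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype)

# With this commit the mask/bias dimensions may be broadcastable: the mask below is
# shared across batch and heads, and the bias is shared across query positions.
attn_mask = torch.ones(1, 1, seqlen_q, seqlen_k, dtype=torch.bool, device=device)
attn_bias = torch.zeros(batch, 1, 1, seqlen_k, device=device, dtype=dtype)

out = flash_dmattn_func(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias)
print(out.shape)  # expected: (batch, seqlen_q, nheads, headdim)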
