Commit 83fb0f7

Removes mask/bias from varlen; defaults nondet
Simplifies the varlen attention API by dropping explicit mask/bias inputs and associated gradients, reducing memory overhead and aligning with the underlying kernels. Avoids padding the key sequence length to multiples of 8 (still pads head size), relying on kernel support to handle ragged sizes and eliminating unnecessary work. Changes the default deterministic flag to False to favor performance; callers can still request deterministic behavior when needed. Updates saved tensors, sanitization, wrappers, returns, and docs to reflect the streamlined interface. Breaking change: callers must remove mask/bias arguments and any reliance on dbias gradients.
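For reference, a minimal usage sketch after this change (not taken from the repo's docs). It assumes the import path follows the file shown below, and that the max_seqlen_k and deterministic keyword names exist on flash_dmattn_varlen_func alongside the parameters visible in the diff; treat those names as assumptions.

import torch
from flash_dmattn.flash_dmattn_interface import flash_dmattn_varlen_func

# Two variable-length sequences packed into one ragged batch.
seqlens = [5, 3]
total, nheads, headdim = sum(seqlens), 8, 64
q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.bfloat16)
# Cumulative sequence lengths used to index into the packed q/k/v tensors.
cu_seqlens = torch.tensor([0, 5, 8], device="cuda", dtype=torch.int32)

# Before this commit the call also took attn_mask and attn_bias; they are gone now.
out = flash_dmattn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),   # assumed keyword, mirroring max_seqlen_q
    deterministic=True,          # opt back in; the default is now False
)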
1 parent 2580b85 commit 83fb0f7

flash_dmattn/flash_dmattn_interface.py

Lines changed: 11 additions & 64 deletions
@@ -131,8 +131,6 @@ def _flash_dmattn_varlen_forward(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    mask: Optional[torch.Tensor],
-    bias: Optional[torch.Tensor],
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     max_seqlen_q: int,
@@ -146,13 +144,11 @@ def _flash_dmattn_varlen_forward(
     seqused_k: Optional[torch.Tensor] = None,
     zero_tensors: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)]
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, softmax_lse, S_dmask = flash_dmattn_gpu.varlen_fwd(
         q,
         k,
         v,
-        mask,
-        bias,
         None,
         cu_seqlens_q,
         cu_seqlens_k,
@@ -176,8 +172,6 @@ def _flash_dmattn_varlen_forward_fake(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    mask: Optional[torch.Tensor],
-    bias: Optional[torch.Tensor],
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     max_seqlen_q: int,
@@ -191,7 +185,7 @@ def _flash_dmattn_varlen_forward_fake(
     seqused_k: Optional[torch.Tensor] = None,
     zero_tensors: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)]
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     paged_kv = block_table is not None
     batch_size = cu_seqlens_q.numel() - 1
     total_q, num_heads, _ = q.shape
@@ -294,20 +288,17 @@ def _flash_dmattn_backward_fake(
 _wrapped_flash_dmattn_backward = _flash_dmattn_backward
 
 
-@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda")
+@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
 def _flash_dmattn_varlen_backward(
     dout: torch.Tensor,
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    mask: Optional[torch.Tensor],
-    bias: Optional[torch.Tensor],
     out: torch.Tensor,
     softmax_lse: torch.Tensor,
     dq: Optional[torch.Tensor],
     dk: Optional[torch.Tensor],
     dv: Optional[torch.Tensor],
-    dbias: Optional[torch.Tensor],
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     max_seqlen_q: int,
@@ -318,20 +309,17 @@ def _flash_dmattn_varlen_backward(
     deterministic: bool,
     zero_tensors: bool = False,
 ) -> torch.Tensor:
-    dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     (
         dq,
         dk,
         dv,
-        dbias,
         softmax_d,
     ) = flash_dmattn_gpu.varlen_bwd(
         dout,
         q,
         k,
         v,
-        mask,
-        bias,
         out,
         softmax_lse,
         dq,
@@ -347,7 +335,7 @@ def _flash_dmattn_varlen_backward(
         softcap,
         deterministic,
     )
-    _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=0.0, neginf=0.0)
+    _sanitize_tensors(dq, dk, dv, nan=0.0, posinf=0.0, neginf=0.0)
     return softmax_d
 
 
@@ -357,8 +345,6 @@ def _flash_dmattn_varlen_backward_fake(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    mask: Optional[torch.Tensor],
-    bias: Optional[torch.Tensor],
     out: torch.Tensor,
     softmax_lse: torch.Tensor,
     dq: Optional[torch.Tensor],
@@ -375,7 +361,7 @@ def _flash_dmattn_varlen_backward_fake(
     deterministic: bool,
     zero_tensors: bool = False,
 ) -> torch.Tensor:
-    dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     batch_size = cu_seqlens_q.numel() - 1
     total_q, num_heads, _ = q.shape
 
@@ -385,8 +371,6 @@ def _flash_dmattn_varlen_backward_fake(
         dk = torch.empty_like(k)
     if dv is None:
         dv = torch.empty_like(v)
-    if dbias is None and bias is not None:
-        dbias = torch.empty_like(bias)
     softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
 
     return softmax_d
@@ -422,7 +406,7 @@ def forward(
         if softcap is None:
             softcap = 0.0
         if deterministic is None:
-            deterministic = True
+            deterministic = False
         if return_softmax is None:
             return_softmax = False
 
@@ -521,8 +505,6 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        mask: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor],
         cu_seqlens_q: torch.Tensor,
         cu_seqlens_k: torch.Tensor,
         max_seqlen_q: int,
@@ -545,7 +527,7 @@ def forward(
         if softcap is None:
             softcap = 0.0
         if deterministic is None:
-            deterministic = True
+            deterministic = False
         if return_softmax is None:
            return_softmax = False
 
@@ -555,21 +537,11 @@ def forward(
             q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
-        seqlen_k_og = k.shape[1]
-        if seqlen_k_og % 8 != 0:
-            k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
-            v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
-            if mask is not None:
-                mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False)
-            if bias is not None:
-                bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0)
 
         out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward(
             q,
             k,
             v,
-            mask,
-            bias,
             cu_seqlens_q,
             cu_seqlens_k,
             max_seqlen_q,
@@ -583,7 +555,7 @@ def forward(
 
         if is_grad:
             ctx.save_for_backward(
-                q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
+                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
             )
             ctx.max_seqlen_q = max_seqlen_q
             ctx.max_seqlen_k = max_seqlen_k
@@ -598,9 +570,8 @@ def forward(
 
     @staticmethod
     def backward(ctx, dout, *args):
-        q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
-        dbias = torch.zeros_like(bias).contiguous() if bias is not None else None
 
         head_size_og = dout.size(2)
         dout_padded = dout
@@ -612,14 +583,11 @@ def backward(ctx, dout, *args):
             q,
             k,
             v,
-            mask,
-            bias,
             out,
             softmax_lse,
             dq,
             dk,
             dv,
-            dbias,
             cu_seqlens_q,
             cu_seqlens_k,
             ctx.max_seqlen_q,
@@ -635,13 +603,7 @@ def backward(ctx, dout, *args):
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
 
-        if ctx.seqlen_k_og % 8 != 0:
-            dk = dk[:, : ctx.seqlen_k_og, :, :]
-            dv = dv[:, : ctx.seqlen_k_og, :, :]
-            if dbias is not None:
-                dbias = dbias[..., : ctx.seqlen_k_og]
-
-        return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_dmattn_func(
@@ -725,8 +687,6 @@ def flash_dmattn_varlen_func(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor],
-    attn_bias: Optional[torch.Tensor],
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     max_seqlen_q: int,
@@ -744,11 +704,6 @@ def flash_dmattn_varlen_func(
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
 
-    Similarity, also supports attn_mask and attn_bias with head dimension of 1, nheads_k or nheads for MQA/GQA.
-    For example, if Q has 6 heads, K, V have 2 heads, then attn_mask and attn_bias can have head dimension
-    of 1, 2 or 6. If it is 1, all heads use the same mask/bias; if it is 2, head 0, 1, 2 of Q use head 0
-    of mask/bias, head 3, 4, 5 of Q use head 1 of mask/bias. If it is 6, each head uses its own mask/bias.
-
     If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
     For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
         1 1 1 1 0
@@ -765,12 +720,6 @@ def flash_dmattn_varlen_func(
         query: torch.Tensor. The query tensor of shape (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
         key: torch.Tensor. The key tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
         value: torch.Tensor. The value tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
-        attn_mask: torch.Tensor, optional. The attention mask boolean tensor of
-            shape (total_q, {nheads|nheads_k|1}, max_seqlen_k) or (total_k, {nheads|nheads_k|1}) to apply to the attention scores.
-            If None, no mask is applied.
-        attn_bias: torch.Tensor, optional. The attention bias float tensor of
-            shape (total_q, {nheads|nheads_k|1}, max_seqlen_k) or (total_k, {nheads|nheads_k|1}) to add to the attention scores.
-            If None, no bias is applied.
         cu_seqlens_q: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into q.
         cu_seqlens_k: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into kv.
         max_seqlen_q: int. Maximum query sequence length in the batch.
@@ -796,8 +745,6 @@ def flash_dmattn_varlen_func(
         query,
         key,
         value,
-        attn_mask,
-        attn_bias,
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
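As an aside, the bottom-right alignment of the causal mask described in the docstring above can be reproduced in plain PyTorch. This is an illustrative sketch only; it does not touch the kernels or the API in this diff.

import torch

seqlen_q, seqlen_k = 2, 5
# Bottom-right aligned causal mask: query position i may attend to key
# position j when j <= i + (seqlen_k - seqlen_q).
i = torch.arange(seqlen_q).unsqueeze(1)
j = torch.arange(seqlen_k).unsqueeze(0)
mask = (j <= i + (seqlen_k - seqlen_q)).int()
print(mask)
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])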
