@@ -131,8 +131,6 @@ def _flash_dmattn_varlen_forward(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    mask: Optional[torch.Tensor],
-    bias: Optional[torch.Tensor],
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     max_seqlen_q: int,
@@ -146,13 +144,11 @@ def _flash_dmattn_varlen_forward(
     seqused_k: Optional[torch.Tensor] = None,
     zero_tensors: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)]
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, softmax_lse, S_dmask = flash_dmattn_gpu.varlen_fwd(
         q,
         k,
         v,
-        mask,
-        bias,
         None,
         cu_seqlens_q,
         cu_seqlens_k,
@@ -176,8 +172,6 @@ def _flash_dmattn_varlen_forward_fake(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    mask: Optional[torch.Tensor],
-    bias: Optional[torch.Tensor],
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     max_seqlen_q: int,
@@ -191,7 +185,7 @@ def _flash_dmattn_varlen_forward_fake(
     seqused_k: Optional[torch.Tensor] = None,
     zero_tensors: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)]
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     paged_kv = block_table is not None
     batch_size = cu_seqlens_q.numel() - 1
     total_q, num_heads, _ = q.shape
@@ -294,20 +288,17 @@ def _flash_dmattn_backward_fake(
 _wrapped_flash_dmattn_backward = _flash_dmattn_backward


-@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda")
+@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
 def _flash_dmattn_varlen_backward(
     dout: torch.Tensor,
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    mask: Optional[torch.Tensor],
-    bias: Optional[torch.Tensor],
     out: torch.Tensor,
     softmax_lse: torch.Tensor,
     dq: Optional[torch.Tensor],
     dk: Optional[torch.Tensor],
     dv: Optional[torch.Tensor],
-    dbias: Optional[torch.Tensor],
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     max_seqlen_q: int,
@@ -318,20 +309,17 @@ def _flash_dmattn_varlen_backward(
     deterministic: bool,
     zero_tensors: bool = False,
 ) -> torch.Tensor:
-    dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     (
         dq,
         dk,
         dv,
-        dbias,
         softmax_d,
     ) = flash_dmattn_gpu.varlen_bwd(
         dout,
         q,
         k,
         v,
-        mask,
-        bias,
         out,
         softmax_lse,
         dq,
@@ -347,7 +335,7 @@ def _flash_dmattn_varlen_backward(
         softcap,
         deterministic,
     )
-    _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=0.0, neginf=0.0)
+    _sanitize_tensors(dq, dk, dv, nan=0.0, posinf=0.0, neginf=0.0)
     return softmax_d


@@ -357,8 +345,6 @@ def _flash_dmattn_varlen_backward_fake(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    mask: Optional[torch.Tensor],
-    bias: Optional[torch.Tensor],
     out: torch.Tensor,
     softmax_lse: torch.Tensor,
     dq: Optional[torch.Tensor],
@@ -375,7 +361,7 @@ def _flash_dmattn_varlen_backward_fake(
     deterministic: bool,
     zero_tensors: bool = False,
 ) -> torch.Tensor:
-    dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     batch_size = cu_seqlens_q.numel() - 1
     total_q, num_heads, _ = q.shape

@@ -385,8 +371,6 @@ def _flash_dmattn_varlen_backward_fake(
         dk = torch.empty_like(k)
     if dv is None:
         dv = torch.empty_like(v)
-    if dbias is None and bias is not None:
-        dbias = torch.empty_like(bias)
     softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)

     return softmax_d
@@ -422,7 +406,7 @@ def forward(
         if softcap is None:
             softcap = 0.0
         if deterministic is None:
-            deterministic = True
+            deterministic = False
         if return_softmax is None:
             return_softmax = False

@@ -521,8 +505,6 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        mask: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor],
         cu_seqlens_q: torch.Tensor,
         cu_seqlens_k: torch.Tensor,
         max_seqlen_q: int,
@@ -545,7 +527,7 @@ def forward(
         if softcap is None:
             softcap = 0.0
         if deterministic is None:
-            deterministic = True
+            deterministic = False
         if return_softmax is None:
             return_softmax = False

@@ -555,21 +537,11 @@ def forward(
             q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
-        seqlen_k_og = k.shape[1]
-        if seqlen_k_og % 8 != 0:
-            k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
-            v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
-            if mask is not None:
-                mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False)
-            if bias is not None:
-                bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0)

         out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward(
             q,
             k,
             v,
-            mask,
-            bias,
             cu_seqlens_q,
             cu_seqlens_k,
             max_seqlen_q,
@@ -583,7 +555,7 @@ def forward(

         if is_grad:
             ctx.save_for_backward(
-                q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
+                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
             )
             ctx.max_seqlen_q = max_seqlen_q
             ctx.max_seqlen_k = max_seqlen_k
@@ -598,9 +570,8 @@ def forward(

     @staticmethod
     def backward(ctx, dout, *args):
-        q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
-        dbias = torch.zeros_like(bias).contiguous() if bias is not None else None

         head_size_og = dout.size(2)
         dout_padded = dout
@@ -612,14 +583,11 @@ def backward(ctx, dout, *args):
             q,
             k,
             v,
-            mask,
-            bias,
             out,
             softmax_lse,
             dq,
             dk,
             dv,
-            dbias,
             cu_seqlens_q,
             cu_seqlens_k,
             ctx.max_seqlen_q,
@@ -635,13 +603,7 @@ def backward(ctx, dout, *args):
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]

-        if ctx.seqlen_k_og % 8 != 0:
-            dk = dk[:, : ctx.seqlen_k_og, :, :]
-            dv = dv[:, : ctx.seqlen_k_og, :, :]
-            if dbias is not None:
-                dbias = dbias[..., : ctx.seqlen_k_og]
-
-        return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None


 def flash_dmattn_func(
@@ -725,8 +687,6 @@ def flash_dmattn_varlen_func(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor],
-    attn_bias: Optional[torch.Tensor],
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     max_seqlen_q: int,
@@ -744,11 +704,6 @@ def flash_dmattn_varlen_func(
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

-    Similarity, also supports attn_mask and attn_bias with head dimension of 1, nheads_k or nheads for MQA/GQA.
-    For example, if Q has 6 heads, K, V have 2 heads, then attn_mask and attn_bias can have head dimension
-    of 1, 2 or 6. If it is 1, all heads use the same mask/bias; if it is 2, head 0, 1, 2 of Q use head 0
-    of mask/bias, head 3, 4, 5 of Q use head 1 of mask/bias. If it is 6, each head uses its own mask/bias.
-
     If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
     For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
         1 1 1 1 0
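
For readers skimming the diff, the MQA/GQA head-grouping rule described in the docstring above can be sketched as follows. This is a minimal illustration; the integer-division mapping and the concrete head counts are assumptions consistent with the docstring text, not code taken from this PR:

```python
# Hypothetical sketch of the MQA/GQA grouping described in the docstring above.
nheads, nheads_k = 6, 2            # query heads vs. shared K/V heads
group_size = nheads // nheads_k    # 3 query heads share each K/V head

for q_head in range(nheads):
    kv_head = q_head // group_size
    print(f"query head {q_head} attends to K/V head {kv_head}")
# query heads 0, 1, 2 -> K/V head 0; query heads 3, 4, 5 -> K/V head 1
```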
@@ -765,12 +720,6 @@ def flash_dmattn_varlen_func(
         query: torch.Tensor. The query tensor of shape (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
         key: torch.Tensor. The key tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
         value: torch.Tensor. The value tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
-        attn_mask: torch.Tensor, optional. The attention mask boolean tensor of
-            shape (total_q, {nheads|nheads_k|1}, max_seqlen_k) or (total_k, {nheads|nheads_k|1}) to apply to the attention scores.
-            If None, no mask is applied.
-        attn_bias: torch.Tensor, optional. The attention bias float tensor of
-            shape (total_q, {nheads|nheads_k|1}, max_seqlen_k) or (total_k, {nheads|nheads_k|1}) to add to the attention scores.
-            If None, no bias is applied.
         cu_seqlens_q: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into q.
         cu_seqlens_k: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into kv.
         max_seqlen_q: int. Maximum query sequence length in the batch.
@@ -796,8 +745,6 @@ def flash_dmattn_varlen_func(
         query,
         key,
         value,
-        attn_mask,
-        attn_bias,
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
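
For context, here is a minimal usage sketch of the varlen entry point after this change, with the removed `attn_mask`/`attn_bias` arguments gone. The import path, tensor shapes, dtypes, and the assumption that the remaining keyword arguments keep their defaults are illustrative guesses, not taken from this diff:

```python
import torch
from flash_dmattn import flash_dmattn_varlen_func  # import path assumed

# Two packed sequences of lengths 3 and 5 (8 tokens total); shapes are illustrative.
seqlens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))

total, nheads, nheads_k, headdim = 8, 6, 2, 64
q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(total, nheads_k, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(total, nheads_k, headdim, device="cuda", dtype=torch.float16)

# After this PR, no attn_mask / attn_bias are passed on the varlen path.
out = flash_dmattn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()),
    max_seqlen_k=int(seqlens.max()),
)
```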