@@ -79,11 +79,9 @@
 if _CAN_USE_FLASH_ATTN_3:
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
-    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
-    flash_attn_3_forward = None

 if _CAN_USE_AITER_ATTN:
     from aiter import flash_attn_func as aiter_flash_attn_func
@@ -623,42 +621,22 @@ def _wrapped_flash_attn_3(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    max_seqlen_q = q.shape[2]
-    max_seqlen_k = k.shape[2]
-
-    out, lse, *_ = flash_attn_3_forward(
+    out, lse, *_ = flash_attn_3_func(
         q=q,
         k=k,
         v=v,
-        k_new=None,
-        v_new=None,
+        softmax_scale=softmax_scale,
+        causal=causal,
         qv=qv,
-        out=None,
-        cu_seqlens_q=None,
-        cu_seqlens_k=None,
-        cu_seqlens_k_new=None,
-        seqused_q=None,
-        seqused_k=None,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        page_table=None,
-        kv_batch_idx=None,
-        leftpad_k=None,
-        rotary_cos=None,
-        rotary_sin=None,
-        seqlens_rotary=None,
         q_descale=q_descale,
         k_descale=k_descale,
         v_descale=v_descale,
-        softmax_scale=softmax_scale,
-        causal=causal,
         window_size=window_size,
         attention_chunk=attention_chunk,
         softcap=softcap,
-        rotary_interleaved=True,
-        scheduler_metadata=None,
         num_splits=num_splits,
         pack_gqa=pack_gqa,
+        deterministic=deterministic,
         sm_margin=sm_margin,
     )
     lse = lse.permute(0, 2, 1)
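For context, a minimal sketch (not part of this change) of driving the public FA3 entry point the same way the rewritten wrapper now does. The tensor shapes, dtype, and CUDA setup below are assumptions for illustration, and it requires flash-attn 3's flash_attn_interface to be installed on a supported GPU:

import torch
from flash_attn_interface import flash_attn_func as flash_attn_3_func

# Standard flash-attn layout, assumed here: (batch, seqlen, num_heads, head_dim).
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

# The public API returns the attention output plus the log-sum-exp, so the
# per-call bookkeeping the internal _flash_attn_forward required
# (max_seqlen_*, cu_seqlens_*, rotary arguments, ...) no longer needs to be
# spelled out by the caller.
out, lse, *_ = flash_attn_3_func(q=q, k=k, v=v, causal=False)
lse = lse.permute(0, 2, 1)  # swap the last two LSE dims, as the wrapper does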