 import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange
+from packaging import version
 
 from colossalai.kernel.kernel_loader import (
     FlashAttentionDaoLoader,
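Note on the added import: `packaging.version` gives PEP 440-aware version comparison, which the hunks below use to branch on the installed flash-attn release. A minimal sketch of the pattern, assuming `flash_attn` is importable (the `FA_NEWER_THAN_2_6_3` name is illustrative, not from this diff):

    from packaging import version

    import flash_attn

    # flash-attn releases after 2.6.3 changed the private varlen-forward API,
    # so downstream code branches once on the installed version.
    FA_NEWER_THAN_2_6_3 = version.parse(flash_attn.__version__) > version.parse("2.6.3")

Plain string comparison would mis-order releases (e.g. "2.10.0" < "2.6.3" lexicographically), which is why `version.parse` is used rather than comparing `__version__` directly.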
@@ -642,16 +643,21 @@ def forward(
         max_seqlen_q = max_seqlen_kv = max_seqlen
         cu_seqlens_half = cu_seqlens // 2
         max_seqlen_half = max_seqlen // 2
-
         misc_kwargs = {
-            "window_size": (-1, -1),
             "alibi_slopes": None,
             "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
             "dropout_p": dropout_p,
             "block_table": None,
             "softcap": 0.0,
             "return_softmax": False,
         }
+        import flash_attn
+
+        if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+            misc_kwargs["window_size_left"] = -1
+            misc_kwargs["window_size_right"] = -1
+        else:
+            misc_kwargs["window_size"] = (-1, -1)
 
         if (
             RingAttention.HALF_INDICES is not None
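Context for this hunk: flash-attn up to 2.6.3 takes the sliding window as a single `window_size=(left, right)` tuple, while later releases split it into `window_size_left` / `window_size_right` scalars; `(-1, -1)` leaves the window unbounded on both sides. A hypothetical helper sketching the same normalization in isolation (the `window_kwargs` name and the `FA_NEWER_THAN_2_6_3` flag are illustrative, not part of flash-attn or this PR):

    def window_kwargs(left: int = -1, right: int = -1) -> dict:
        # Return sliding-window kwargs in whichever form the installed
        # flash-attn expects; -1 disables the window on that side.
        if FA_NEWER_THAN_2_6_3:
            return {"window_size_left": left, "window_size_right": right}
        return {"window_size": (left, right)}

With such a helper, the conditional above would collapse to `misc_kwargs.update(window_kwargs())`.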
@@ -707,26 +713,39 @@ def forward(
 
         # Helper to pass args to FA
         def _forward(q, k, v, causal):
-            (
-                _,
-                _,
-                _,
-                _,
-                out,
-                softmax_lse,
-                _,
-                rng_state,
-            ) = _flash_attn_forward(
-                q,
-                k,
-                v,
-                cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
-                cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
-                max_seqlen_q if q.shape[0] == t else max_seqlen_half,
-                max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
-                causal=causal,
-                **misc_kwargs,
-            )
+            if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+                (out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward(
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
+                    cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
+                    max_seqlen_q if q.shape[0] == t else max_seqlen_half,
+                    max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
+                    causal=causal,
+                    **misc_kwargs,
+                )
+            else:
+                (
+                    _,
+                    _,
+                    _,
+                    _,
+                    out,
+                    softmax_lse,
+                    _,
+                    rng_state,
+                ) = _flash_attn_forward(
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
+                    cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
+                    max_seqlen_q if q.shape[0] == t else max_seqlen_half,
+                    max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
+                    causal=causal,
+                    **misc_kwargs,
+                )
             return out, softmax_lse, rng_state
 
         def _kv_comm(i):
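The second branch exists because `_flash_attn_forward`'s return value changed after 2.6.3: older releases return an 8-tuple from which only `out`, `softmax_lse`, and `rng_state` are consumed here, while newer ones return a flat `(out, softmax_lse, S_dmask, rng_state)`. Both branches surface the same `(out, softmax_lse, rng_state)` triple, so the version check could be hoisted out of the hot path; a hypothetical shim sketching that (the wrapper name and the module-level flag are illustrative, not from this PR):

    def _fa_forward_compat(*args, **kwargs):
        # Normalize _flash_attn_forward's return value across flash-attn
        # versions to (out, softmax_lse, rng_state).
        results = _flash_attn_forward(*args, **kwargs)
        if FA_NEWER_THAN_2_6_3:
            out, softmax_lse, _s_dmask, rng_state = results
        else:
            # flash-attn <= 2.6.3 returns an 8-tuple; select the needed fields
            _, _, _, _, out, softmax_lse, _, rng_state = results
        return out, softmax_lse, rng_state

This would also avoid re-parsing the version string on every call to `_forward`.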