Skip to content

Commit c2fe313

Browse files
authored
[hotfix] fix flash attn window_size err (#6132)
* [fix] fix flash attn * [hotfix] fix flash-attn version * [fix] fix flash_attn version * [fix] fix flash-attn versions * [fix] fix flash-attn "not enough values to unpack" error * [fix] fix test_ring_attn * [fix] fix test ring attn
1 parent a259651 commit c2fe313

File tree

1 file changed

+41
-22
lines changed
  • colossalai/shardformer/layer

1 file changed

+41
-22
lines changed

colossalai/shardformer/layer/attn.py

Lines changed: 41 additions & 22 deletions
Original file line number · Diff line number · Diff line change
@@ -6,6 +6,7 @@
66
import torch.distributed as dist
77
import torch.nn.functional as F
88
from einops import rearrange
9+
from packaging import version
910

1011
from colossalai.kernel.kernel_loader import (
1112
FlashAttentionDaoLoader,
@@ -642,16 +643,21 @@ def forward(
642643
max_seqlen_q = max_seqlen_kv = max_seqlen
643644
cu_seqlens_half = cu_seqlens // 2
644645
max_seqlen_half = max_seqlen // 2
645-
646646
misc_kwargs = {
647-
"window_size": (-1, -1),
648647
"alibi_slopes": None,
649648
"softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
650649
"dropout_p": dropout_p,
651650
"block_table": None,
652651
"softcap": 0.0,
653652
"return_softmax": False,
654653
}
654+
import flash_attn
655+
656+
if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
657+
misc_kwargs["window_size_left"] = -1
658+
misc_kwargs["window_size_right"] = -1
659+
else:
660+
misc_kwargs["window_size"] = (-1, -1)
655661

656662
if (
657663
RingAttention.HALF_INDICES is not None
@@ -707,26 +713,39 @@ def forward(
707713

708714
# Helper to pass args to FA
709715
def _forward(q, k, v, causal):
710-
(
711-
_,
712-
_,
713-
_,
714-
_,
715-
out,
716-
softmax_lse,
717-
_,
718-
rng_state,
719-
) = _flash_attn_forward(
720-
q,
721-
k,
722-
v,
723-
cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
724-
cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
725-
max_seqlen_q if q.shape[0] == t else max_seqlen_half,
726-
max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
727-
causal=causal,
728-
**misc_kwargs,
729-
)
716+
if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
717+
(out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward(
718+
q,
719+
k,
720+
v,
721+
cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
722+
cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
723+
max_seqlen_q if q.shape[0] == t else max_seqlen_half,
724+
max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
725+
causal=causal,
726+
**misc_kwargs,
727+
)
728+
else:
729+
(
730+
_,
731+
_,
732+
_,
733+
_,
734+
out,
735+
softmax_lse,
736+
_,
737+
rng_state,
738+
) = _flash_attn_forward(
739+
q,
740+
k,
741+
v,
742+
cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
743+
cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
744+
max_seqlen_q if q.shape[0] == t else max_seqlen_half,
745+
max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
746+
causal=causal,
747+
**misc_kwargs,
748+
)
730749
return out, softmax_lse, rng_state
731750

732751
def _kv_comm(i):

0 commit comments

Comments (0)