
Commit 148506c

[coloattention]modify coloattention (#5627)
* modify coloattention
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* fix
* fix fxi
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7ee569b commit 148506c

File tree: 3 files changed, +16 -22 lines


colossalai/kernel/kernel_loader.py

Lines changed: 0 additions & 4 deletions

@@ -113,10 +113,6 @@ class FlashAttentionLoader(KernelLoader):
     ]


-class FlashAttentionWithPaddingMaskLoader(KernelLoader):
-    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
-
-
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
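
With FlashAttentionWithPaddingMaskLoader removed, padded masks are served by the remaining loaders. For orientation, a minimal sketch of how a KernelLoader subclass such as FlashAttentionWithCustomMaskLoader is typically consumed; the load() call returning a kernel from the first compatible registered extension is an assumption about KernelLoader's interface, not something shown in this hunk.

# Hypothetical usage sketch: ask a loader for a kernel that is usable on this machine.
from colossalai.kernel.kernel_loader import FlashAttentionWithCustomMaskLoader

loader = FlashAttentionWithCustomMaskLoader()
attn_func = loader.load()  # assumed: picks the first available extension in REGISTRY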

colossalai/shardformer/layer/attn.py

Lines changed: 15 additions & 8 deletions

@@ -8,7 +8,6 @@
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
     FlashAttentionWithCustomMaskLoader,
-    FlashAttentionWithPaddingMaskLoader,
     KernelLoader,
 )


@@ -65,15 +64,17 @@ def _init_kernels_dispatch():
             half_dispatch_map = {
                 None: FlashAttentionLoader(),
                 AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
-                AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
+                AttnMaskType.PADDED: FlashAttentionLoader(),
                 AttnMaskType.CAUSAL: FlashAttentionLoader(),
-                AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
+                AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
             }
             # fp32
             float_dispatch_map = {
                 None: FlashAttentionForFloatAndCustomMaskLoader(),
                 AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
+                AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
                 AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
+                AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
             }
             ColoAttention._kernel_dispatch_map = {
                 torch.float16: half_dispatch_map,
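
Net effect of the two hunks above: AttnMaskType.PADDED and AttnMaskType.PADDED_CAUSAL now resolve to the same generic loaders as causal and unmasked attention, instead of the deleted padding-mask loader. A minimal lookup sketch, assuming the private _init_kernels_dispatch helper and _kernel_dispatch_map attribute can be poked directly as the hunk suggests (illustration only, not a public API):

import torch
from colossalai.shardformer.layer import AttnMaskType, ColoAttention

# Build the dtype -> mask-type -> loader tables (private helper shown above).
ColoAttention._init_kernels_dispatch()

# After this commit, fp16 PADDED attention is dispatched through FlashAttentionLoader.
loader = ColoAttention._kernel_dispatch_map[torch.float16][AttnMaskType.PADDED]
print(type(loader).__name__)  # expected: FlashAttentionLoader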
@@ -140,16 +141,22 @@ def prepare_attn_kwargs(
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
             attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
         else:
+            assert q_padding_mask.shape == (
+                b,
+                s_q,
+            ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
+            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
                 # self attention
                 kv_padding_mask = q_padding_mask
-            assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
+                max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
+            else:
+                max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+            assert kv_padding_mask.shape == (
                 b,
                 s_kv,
-            ), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
-            attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
-            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
-            max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+            ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
+            attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,

tests/test_shardformer/test_flash_attention.py

Lines changed: 1 addition & 10 deletions

@@ -4,11 +4,7 @@
 import torch
 from torch.testing import assert_close

-from colossalai.kernel.kernel_loader import (
-    FlashAttentionLoader,
-    FlashAttentionWithCustomMaskLoader,
-    FlashAttentionWithPaddingMaskLoader,
-)
+from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer.attn import invert_mask
 from colossalai.testing import clear_cache_before_run, parameterize

@@ -119,11 +115,6 @@ def test_flash_attn_func(dtype: torch.dtype):
         if ext.is_available():
             ext.assert_compatible()
             avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))
-    for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY:
-        ext = ext_cls()
-        if ext.is_available():
-            ext.assert_compatible()
-            avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True))

     test_sets = {
         "none": (lambda dtype: ({}, None), avail_attn_funcs),
