     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
     FlashAttentionWithCustomMaskLoader,
-    FlashAttentionWithPaddingMaskLoader,
     KernelLoader,
 )

@@ -65,15 +64,17 @@ def _init_kernels_dispatch():
             half_dispatch_map = {
                 None: FlashAttentionLoader(),
                 AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
-                AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
+                AttnMaskType.PADDED: FlashAttentionLoader(),
                 AttnMaskType.CAUSAL: FlashAttentionLoader(),
-                AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
+                AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
             }
             # fp32
             float_dispatch_map = {
                 None: FlashAttentionForFloatAndCustomMaskLoader(),
                 AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
+                AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
                 AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
+                AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
             }
             ColoAttention._kernel_dispatch_map = {
                 torch.float16: half_dispatch_map,
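For context, the maps above are keyed first by dtype and then by AttnMaskType, so with this change every half-precision mask type (including PADDED and PADDED_CAUSAL) resolves to the plain FlashAttentionLoader. Below is a minimal sketch of how such a two-level dispatch can be consulted; the stand-in loader class and the pick_kernel helper are illustrative assumptions, not the repository's actual API.

import torch
from enum import Enum

class AttnMaskType(Enum):
    CUSTOM = 0
    PADDED = 1
    CAUSAL = 2
    PADDED_CAUSAL = 3

class FlashAttentionLoader:  # stand-in for the real kernel loader
    def load(self):
        return "flash-attn kernel"

# Two-level dispatch mirroring the diff: dtype -> mask type -> loader.
kernel_dispatch_map = {
    torch.float16: {
        None: FlashAttentionLoader(),
        AttnMaskType.PADDED: FlashAttentionLoader(),         # previously a dedicated padding-mask loader
        AttnMaskType.CAUSAL: FlashAttentionLoader(),
        AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),  # previously a dedicated padding-mask loader
    },
}

def pick_kernel(dtype, mask_type):
    # Hypothetical helper: narrow by dtype first, then by mask type.
    return kernel_dispatch_map[dtype][mask_type].load()

print(pick_kernel(torch.float16, AttnMaskType.PADDED))  # -> "flash-attn kernel"
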
@@ -140,16 +141,22 @@ def prepare_attn_kwargs(
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
             attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
         else:
+            assert q_padding_mask.shape == (
+                b,
+                s_q,
+            ), f"q_padding_mask shape {q_padding_mask.shape} should be ({b}, {s_q})."
+            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
                 # self attention
                 kv_padding_mask = q_padding_mask
-            assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
+                max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
+            else:
+                max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+            assert kv_padding_mask.shape == (
                 b,
                 s_kv,
-            ), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
-            attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
-            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
-            max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+            ), f"kv_padding_mask shape {kv_padding_mask.shape} should be ({b}, {s_kv})."
+            attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
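The reworked branch computes the query-side unpadding metadata once and reuses it for self attention, so get_pad_info is only called on kv_padding_mask in the cross-attention case. Below is a minimal sketch of that bookkeeping, assuming get_pad_info returns (max sequence length, cumulative sequence lengths, flat indices of kept tokens) like the usual FlashAttention unpad helpers; the function name get_pad_info_sketch and the example tensors are illustrative only.

import torch
import torch.nn.functional as F

def get_pad_info_sketch(padding_mask: torch.Tensor):
    # Assumed behaviour of get_pad_info: per-sequence valid-token counts,
    # their cumulative sums (with a leading 0), and flat indices of kept tokens.
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)
    max_seqlen = int(seqlens.max())
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    return max_seqlen, cu_seqlens, indices

b, s = 2, 4
q_padding_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 1, 1]])
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info_sketch(q_padding_mask)
print(max_seqlen_q)   # 4
print(cu_seqlens_q)   # tensor([0, 3, 7], dtype=torch.int32)
print(q_indices)      # tensor([0, 1, 2, 4, 5, 6, 7])

# Self attention reuses the query-side results; the dense mask is the 1-D
# padding mask broadcast over the other sequence dimension, as in the new code.
attention_mask = q_padding_mask[:, None, :].expand(b, s, s)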