Commit feee72b

Add flash decoding with alibi triton op
Signed-off-by: char-1ee <[email protected]>
1 parent 1e6efd1 commit feee72b

2 files changed: +358 −62 lines changed

colossalai/inference/modeling/models/bloom.py

Lines changed: 36 additions & 41 deletions
@@ -6,7 +6,7 @@
 )
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.shardformer.shard import ShardConfig
-from colossalai.kernel.triton import flash_decoding_attention, get_xine_cache
+from colossalai.kernel.triton import flash_decoding_attention_with_alibi
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.jit.bias_gelu import GeLUFunction
 from colossalai.kernel.jit.bias_dropout_add import bias_dropout_add_fused_inference
@@ -42,33 +42,39 @@
     logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")


-# A temporary python implementation of ALibi.
-def _get_bias_matrix(n_heads: int):
-    def _get_bias_matrix_pow_of_2(n_heads):
+# The Alibi implementation is adapted from https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+def _get_alibi_slopes(n_heads: int):
+    def _get_alibi_slopes_pow_of_2(n_heads):
         start = (2 ** (-2 ** -(math.log2(n_heads) - 3)))
         ratio = start
         return [start * ratio ** i for i in range(n_heads)]

     if math.log2(n_heads).is_integer():
-        return _get_bias_matrix_pow_of_2(n_heads)
+        return _get_alibi_slopes_pow_of_2(n_heads)
     else:
         closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
-        return _get_bias_matrix_pow_of_2(closest_power_of_2) + _get_bias_matrix(2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]
+        return _get_alibi_slopes_pow_of_2(closest_power_of_2) + _get_alibi_slopes(2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]

-def _fill_with_neg_inf(t):
-    return t.float().fill_(float("-inf")).type_as(t)
+def _get_alibi_tensor(n_heads: int, mask: torch.Tensor):
+    slopes = _get_alibi_slopes(n_heads).to(mask.device)
+    distance = mask.cumsum(dim=-1)
+    return distance[:, :, None] * slopes[None, None, :]
+
+
+# def _fill_with_neg_inf(t):
+#     return t.float().fill_(float("-inf")).type_as(t)

-# (Register buffer within BloomModel), only use for inference
-def _get_alibi_mask(max_pos: int, n_heads: int):
-    slopes = torch.Tensor(_get_bias_matrix(n_heads))
-    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0) \
-        .expand(n_heads, -1, -1) \
-        .view(n_heads, 1, max_pos)
+# # (Register buffer within BloomModel), only use for inference
+# def _get_alibi_tensor(max_pos: int, n_heads: int):
+#     slopes = torch.Tensor(_get_alibi_slopes(n_heads))
+#     alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0) \
+#         .expand(n_heads, -1, -1) \
+#         .view(n_heads, 1, max_pos)

-    alibi_mask = torch.triu (
-        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
-    )
-    return alibi_mask.unsqueeze(0) + alibi
+#     alibi_mask = torch.triu (
+#         _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
+#     )
+#     return alibi_mask.unsqueeze(0) + alibi


 # TODO
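
For context, a minimal, self-contained sketch of what the new ALiBi helpers compute (illustrative only; the committed _get_alibi_slopes returns a plain Python list, so the torch.tensor wrapping here is an assumption): each head gets one slope from a geometric series, and the bias added to attention scores grows linearly with key position, scaled by that slope.

    import math
    import torch

    def get_alibi_slopes(n_heads: int) -> torch.Tensor:
        # One slope per head. For power-of-two head counts this is a geometric series
        # starting at 2^(-8/n_heads); otherwise the extra heads are interpolated from
        # the next power of two, following the referenced ALiBi implementation.
        def pow_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            return [start * start**i for i in range(n)]

        if math.log2(n_heads).is_integer():
            slopes = pow_of_2(n_heads)
        else:
            closest = 2 ** math.floor(math.log2(n_heads))
            slopes = pow_of_2(closest) + pow_of_2(2 * closest)[0::2][: n_heads - closest]
        return torch.tensor(slopes)

    def alibi_bias(slopes: torch.Tensor, kv_len: int) -> torch.Tensor:
        # The bias depends only on the key position; softmax is shift-invariant per query row,
        # so slope * key_pos is equivalent to the usual slope * (key_pos - query_pos).
        distance = torch.arange(kv_len, dtype=slopes.dtype)
        return distance[None, :] * slopes[:, None]        # [n_heads, kv_len]

    print(get_alibi_slopes(8))                        # 0.5, 0.25, ... down to ~0.0039
    print(alibi_bias(get_alibi_slopes(8), 4).shape)   # torch.Size([8, 4])
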
@@ -77,7 +83,6 @@ def bloom_model_forward(
     input_tokens_ids: torch.Tensor,
     output_tensor: torch.Tensor,
     inputmetadata: InputMetaData,
-    attention_mask: torch.Tensor = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
     use_cuda_kernel: Optional[bool] = True,
@@ -87,7 +92,7 @@ def bloom_model_forward(
     def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = False):
         if is_prompts:
             is_prompts = False
-            self.register_buffer("future_mask", _get_alibi_mask())
+            self.register_buffer("future_mask", _get_alibi_tensor())

     is_prompts = inputmetadata.is_prompts
     block_tables = inputmetadata.block_tables
@@ -105,23 +110,18 @@ def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = Fal
     if use_cuda_kernel:
         if inputmetadata != torch.float32 and use_flash_attn2:
             cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
-
-    # TODO: need pass deal with past_seq_length (k, v cache related)
-    # alibi = get_alibi_mask(hidden_states)

     seq_length_with_past = sequence_lengths

-    if is_prompts:
-        is_prompts = False
-        self.register_buffer("future_mask", _get_alibi_mask(self.n_head, self.max_cache_pos).to(hidden_states), persistent=False)
-    if seq_length_with_past > self.max_cache_pos:
-        self.max_cache_pos = seq_length_with_past
-        self.register_buffer("future_mask", _get_alibi_mask(self.n_head, self.max_cache_pos).to(hidden_states), persistent=False)
-
-    alibi = _get_bias_matrix(self.n_head)
-    alibi_mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
-    attention_mask = alibi_mask # refer to baichuan_13b
+    # if is_prompts:
+    #     is_prompts = False
+    #     self.register_buffer("future_mask", _get_alibi_tensor(self.n_head, self.max_cache_pos).to(hidden_states), persistent=False)
+    # if seq_length_with_past > self.max_cache_pos:
+    #     self.max_cache_pos = seq_length_with_past
+    #     self.register_buffer("future_mask", _get_alibi_tensor(self.n_head, self.max_cache_pos).to(hidden_states), persistent=False)

+    alibi = _get_alibi_slopes(self.n_head)
+    # alibi_mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]

     sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
     norm_output = torch.empty_like(hidden_states)
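
With this change the model forward hands the attention layers the raw per-head slopes instead of materializing a dense [n_heads, seq, seq] ALiBi mask. As a rough reference for what the fused decoding kernel is expected to compute for a single new token (a sketch only, not the Triton op; the head-major shapes below are assumptions), the bias is simply added to the query-key scores before the softmax:

    import torch

    def decode_attention_with_alibi_ref(q, k, v, slopes, sm_scale):
        # q: [n_heads, head_dim] for the token being decoded,
        # k, v: [kv_len, n_heads, head_dim], slopes: [n_heads].
        kv_len = k.shape[0]
        scores = torch.einsum("hd,khd->hk", q, k) * sm_scale          # [n_heads, kv_len]
        bias = slopes[:, None] * torch.arange(kv_len, dtype=q.dtype)  # ALiBi bias per key position
        probs = torch.softmax(scores + bias, dim=-1)
        return torch.einsum("hk,khd->hd", probs, v)                   # [n_heads, head_dim]

    n_heads, head_dim, kv_len = 4, 16, 10
    slopes = torch.tensor([0.25, 0.0625, 0.015625, 0.00390625])  # ALiBi slopes for 4 heads
    q = torch.randn(n_heads, head_dim)
    k = torch.randn(kv_len, n_heads, head_dim)
    v = torch.randn(kv_len, n_heads, head_dim)
    out = decode_attention_with_alibi_ref(q, k, v, slopes, sm_scale=head_dim**-0.5)
    print(out.shape)   # torch.Size([4, 16])
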
@@ -144,10 +144,7 @@ def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = Fal
             sm_scale=sm_scale,
             use_cuda_kernel=use_cuda_kernel,
             high_precision=high_precision,
-            attention_mask=attention_mask,
         )
-
-    # TODO: is_prompt

     hidden_states = self.ln_f(hidden_states)
     return hidden_states
@@ -186,7 +183,6 @@ def bloom_block_forward(
     v_cache: torch.Tensor,
     sequence_lengths: torch.Tensor,
     fd_inter_tensor: FDIntermTensors,
-    attention_mask: torch.Tensor = None,
     is_prompts: bool = True,
     is_verifier: bool = False,
     tokens_to_verify: int = None,
@@ -212,7 +208,6 @@ def bloom_block_forward(
         hidden_states=layernorm_output,
         residual=residual,
         alibi=alibi,
-        attention_mask=attention_mask,
         hidden_states=hidden_states,
         block_tables=block_tables,
         k_cache=k_cache,
@@ -289,8 +284,7 @@ def from_native_module(module: BloomAttention, *args, **kwargs) -> BloomAttentio
     def forward(
         self,
         hidden_states: torch.Tensor,
-        alibi: torch.Tensor, # alibi slopes
-        attention_mask: torch.Tensor,
+        alibi: torch.Tensor,
         block_tables: torch.Tensor,
         k_cache: torch.Tensor,
         v_cache: torch.Tensor,
@@ -331,10 +325,11 @@ def forward(
             )


-            attn_output = flash_decoding_attention(
+            attn_output = flash_decoding_attention_with_alibi(
                 q=query_states,
                 k_cache=k_cache,
                 v_cache=v_cache,
+                alibi=alibi,
                 kv_seq_len=sequence_lengths,
                 block_tables=block_tables,
                 block_size=block_size,
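
The remaining keyword arguments (kv_seq_len, block_tables, block_size) describe the paged KV cache the kernel reads from. As a hedged illustration of what those arguments mean, here is a sketch that gathers one sequence's contiguous K/V back out of a paged cache, assuming a [num_blocks, num_heads, block_size, head_dim] cache layout (the layout the new Triton op actually uses is not shown in this file's diff):

    import torch

    def gather_paged_kv(k_cache, v_cache, block_table, kv_len, block_size):
        # k_cache / v_cache: [num_blocks, num_heads, block_size, head_dim] (assumed layout).
        # block_table: physical block ids for one sequence; kv_len: its current length.
        n_blocks = (kv_len + block_size - 1) // block_size
        k = torch.cat([k_cache[block_table[i]] for i in range(n_blocks)], dim=1)
        v = torch.cat([v_cache[block_table[i]] for i in range(n_blocks)], dim=1)
        return k[:, :kv_len], v[:, :kv_len]   # [num_heads, kv_len, head_dim] each

    k_cache = torch.randn(16, 4, 8, 64)      # illustrative sizes
    v_cache = torch.randn(16, 4, 8, 64)
    block_table = torch.tensor([3, 7, 1])    # this sequence occupies physical blocks 3, 7, 1
    k, v = gather_paged_kv(k_cache, v_cache, block_table, kv_len=20, block_size=8)
    print(k.shape)   # torch.Size([4, 20, 64])
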
