)
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.shardformer.shard import ShardConfig
- from colossalai.kernel.triton import flash_decoding_attention, get_xine_cache
+ from colossalai.kernel.triton import flash_decoding_attention_with_alibi
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.jit.bias_gelu import GeLUFunction
from colossalai.kernel.jit.bias_dropout_add import bias_dropout_add_fused_inference
    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")


- # A temporary python implementation of ALibi.
- def _get_bias_matrix(n_heads: int):
-     def _get_bias_matrix_pow_of_2(n_heads):
+ # The Alibi implementation is adapted from https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+ def _get_alibi_slopes(n_heads: int):
+     def _get_alibi_slopes_pow_of_2(n_heads):
          start = (2 ** (-(2 ** -(math.log2(n_heads) - 3))))
          ratio = start
          return [start * ratio**i for i in range(n_heads)]

      if math.log2(n_heads).is_integer():
-         return _get_bias_matrix_pow_of_2(n_heads)
+         return _get_alibi_slopes_pow_of_2(n_heads)
      else:
          closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
-         return _get_bias_matrix_pow_of_2(closest_power_of_2) + _get_bias_matrix(2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]
+         return _get_alibi_slopes_pow_of_2(closest_power_of_2) + _get_alibi_slopes(2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]

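# Illustrative check (not part of the commit): for a power-of-two head count the
# slope schedule above is a plain geometric sequence, so the new helper can be
# sanity-checked by hand. A minimal sketch, assuming `_get_alibi_slopes` from the
# hunk above is importable:
slopes = _get_alibi_slopes(8)
# start = 2 ** -(2 ** -(log2(8) - 3)) = 2 ** -1 = 0.5, and ratio = start = 0.5
assert slopes == [0.5 ** (i + 1) for i in range(8)]
# Non-power-of-two counts reuse the closest power of two and interleave slopes
# from the next one, so there is still exactly one slope per head.
assert len(_get_alibi_slopes(12)) == 12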
- def _fill_with_neg_inf(t):
-     return t.float().fill_(float("-inf")).type_as(t)
+ def _get_alibi_tensor(n_heads: int, mask: torch.Tensor):
+     slopes = _get_alibi_slopes(n_heads).to(mask.device)
+     distance = mask.cumsum(dim=-1)
+     return distance[:, :, None] * slopes[None, None, :]
+
+
+ # def _fill_with_neg_inf(t):
+ #     return t.float().fill_(float("-inf")).type_as(t)
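# Illustrative usage (not part of the commit): building the additive ALiBi bias
# from a padding mask, mirroring `_get_alibi_tensor` above. A minimal sketch; the
# slopes are wrapped in a tensor here, since `_get_alibi_slopes` returns a list:
import torch

n_heads, seq_len = 4, 6
mask = torch.ones(1, seq_len)                        # (batch, seq_len), 1 marks a real token
slopes = torch.tensor(_get_alibi_slopes(n_heads))    # one slope per attention head
distance = mask.cumsum(dim=-1)                       # positions 1..seq_len along the sequence
bias = distance[:, :, None] * slopes[None, None, :]  # (batch, seq_len, n_heads)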

- # (Register buffer within BloomModel), only use for inference
- def _get_alibi_mask(max_pos: int, n_heads: int):
-     slopes = torch.Tensor(_get_bias_matrix(n_heads))
-     alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0) \
-         .expand(n_heads, -1, -1) \
-         .view(n_heads, 1, max_pos)
+ # # (Register buffer within BloomModel), only use for inference
+ # def _get_alibi_tensor(max_pos: int, n_heads: int):
+ #     slopes = torch.Tensor(_get_alibi_slopes(n_heads))
+ #     alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0) \
+ #         .expand(n_heads, -1, -1) \
+ #         .view(n_heads, 1, max_pos)

-     alibi_mask = torch.triu(
-         _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
-     )
-     return alibi_mask.unsqueeze(0) + alibi
+ #     alibi_mask = torch.triu(
+ #         _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
+ #     )
+ #     return alibi_mask.unsqueeze(0) + alibi


# TODO
@@ -77,7 +83,6 @@ def bloom_model_forward(
    input_tokens_ids: torch.Tensor,
    output_tensor: torch.Tensor,
    inputmetadata: InputMetaData,
-   attention_mask: torch.Tensor = None,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
    use_cuda_kernel: Optional[bool] = True,
@@ -87,7 +92,7 @@ def bloom_model_forward(
    def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = False):
        if is_prompts:
            is_prompts = False
-           self.register_buffer("future_mask", _get_alibi_mask())
+           self.register_buffer("future_mask", _get_alibi_tensor())

    is_prompts = inputmetadata.is_prompts
    block_tables = inputmetadata.block_tables
@@ -105,23 +110,18 @@ def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = Fal
    if use_cuda_kernel:
        if inputmetadata != torch.float32 and use_flash_attn2:
            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
-
-   # TODO: need pass deal with past_seq_length (k, v cache related)
-   # alibi = get_alibi_mask(hidden_states)

    seq_length_with_past = sequence_lengths

-   if is_prompts:
-       is_prompts = False
-       self.register_buffer("future_mask", _get_alibi_mask(self.n_head, self.max_cache_pos).to(hidden_states), persistent=False)
-   if seq_length_with_past > self.max_cache_pos:
-       self.max_cache_pos = seq_length_with_past
-       self.register_buffer("future_mask", _get_alibi_mask(self.n_head, self.max_cache_pos).to(hidden_states), persistent=False)
-
-   alibi = _get_bias_matrix(self.n_head)
-   alibi_mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
-   attention_mask = alibi_mask  # refer to baichuan_13b
+   # if is_prompts:
+   #     is_prompts = False
+   #     self.register_buffer("future_mask", _get_alibi_tensor(self.n_head, self.max_cache_pos).to(hidden_states), persistent=False)
+   # if seq_length_with_past > self.max_cache_pos:
+   #     self.max_cache_pos = seq_length_with_past
+   #     self.register_buffer("future_mask", _get_alibi_tensor(self.n_head, self.max_cache_pos).to(hidden_states), persistent=False)

+   alibi = _get_alibi_slopes(self.n_head)
+   # alibi_mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]

    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
    norm_output = torch.empty_like(hidden_states)
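# Context for the hunk above (not part of the commit): the model forward now hands
# attention only the per-head ALiBi slopes instead of a dense
# (n_heads, seq_len, seq_len) mask; the bias can be rebuilt as slope * relative
# distance and added to the attention scores. A minimal prefill sketch, assuming
# `_get_alibi_slopes` from earlier in this file:
import torch

n_heads, seq_len = 4, 5
slopes = torch.tensor(_get_alibi_slopes(n_heads))             # (n_heads,)
positions = torch.arange(seq_len)
rel = positions[None, :] - positions[:, None]                 # (seq_len, seq_len), j - i
bias = slopes[:, None, None] * rel[None, :, :]                # (n_heads, seq_len, seq_len)
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
scores_bias = bias + causal                                   # added to q @ k.T before softmax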
@@ -144,10 +144,7 @@ def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = Fal
        sm_scale=sm_scale,
        use_cuda_kernel=use_cuda_kernel,
        high_precision=high_precision,
-       attention_mask=attention_mask,
    )
-
-   # TODO: is_prompt

    hidden_states = self.ln_f(hidden_states)
    return hidden_states
@@ -186,7 +183,6 @@ def bloom_block_forward(
    v_cache: torch.Tensor,
    sequence_lengths: torch.Tensor,
    fd_inter_tensor: FDIntermTensors,
-   attention_mask: torch.Tensor = None,
    is_prompts: bool = True,
    is_verifier: bool = False,
    tokens_to_verify: int = None,
@@ -212,7 +208,6 @@ def bloom_block_forward(
        hidden_states=layernorm_output,
        residual=residual,
        alibi=alibi,
-       attention_mask=attention_mask,
        hidden_states=hidden_states,
        block_tables=block_tables,
        k_cache=k_cache,
@@ -289,8 +284,7 @@ def from_native_module(module: BloomAttention, *args, **kwargs) -> BloomAttentio
    def forward(
        self,
        hidden_states: torch.Tensor,
-       alibi: torch.Tensor,  # alibi slopes
-       attention_mask: torch.Tensor,
+       alibi: torch.Tensor,
        block_tables: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
@@ -331,10 +325,11 @@ def forward(
    )


-   attn_output = flash_decoding_attention(
+   attn_output = flash_decoding_attention_with_alibi(
        q=query_states,
        k_cache=k_cache,
        v_cache=v_cache,
+       alibi=alibi,
        kv_seq_len=sequence_lengths,
        block_tables=block_tables,
        block_size=block_size,
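# Reference note (not part of the commit): `flash_decoding_attention_with_alibi` is
# expected to fold the ALiBi bias into the decode-step scores. A minimal, non-paged
# sketch of that computation for a single query token and a single head, where
# `slope` is the head's ALiBi slope and `k`/`v` are the cached keys/values:
import torch

def decode_step_with_alibi(q, k, v, slope, sm_scale):
    # q: (head_dim,), k/v: (kv_len, head_dim)
    kv_len = k.shape[0]
    scores = (k @ q) * sm_scale                       # (kv_len,)
    distance = torch.arange(kv_len) - (kv_len - 1)    # 0 for the newest token, negative for older ones
    scores = scores + slope * distance                # ALiBi: penalty grows with distance into the past
    weights = torch.softmax(scores, dim=-1)
    return weights @ v                                # (head_dim,)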