# This code is adapted from huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
+ import math
from typing import Optional, Tuple

import torch
import torch.nn as nn

from colossalai.inference.flash_decoding_utils import FDIntermTensors
- from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
from colossalai.kernel.kernel_loader import InferenceOpsLoader
+ from colossalai.kernel.triton import (
+     context_attention_unpadded,
+     copy_k_to_blocked_cache,
+     decoding_fused_rotary_embedding,
+     flash_decoding_attention,
+     rms_layernorm,
+     rotary_embedding,
+ )
from colossalai.logging import get_dist_logger

+ logger = get_dist_logger(__name__)
+
+ try:
+     from flash_attn import flash_attn_varlen_func
+
+     use_flash_attn2 = True
+ except ImportError:
+     use_flash_attn2 = False
+     logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
+
inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)

+ # alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
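+ # e.g. num_heads=8 yields slopes [2**-1, 2**-2, ..., 2**-8]; for a non-power-of-two head count,
+ # the remaining heads take odd powers of sqrt(base), following the BLOOM/ALiBi scheme linked above.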
+ def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
+     closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+     base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
+     powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
+     slopes = torch.pow(base, powers)
+     if closest_power_of_2 != num_heads:
+         extra_base = torch.tensor(
+             2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
+         )
+         num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+         extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
+         slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+     return slopes
+
+
+ def baichuan_rmsnorm_forward(
+     self,
+     hidden_states: torch.Tensor,
+     norm_output: torch.Tensor,
+     residual: torch.Tensor = None,
+     use_cuda_kernel: bool = True,
+ ):
+     # Used to address the issue of inconsistent epsilon variable names in baichuan2 7b and 13b.
+     if hasattr(self, "variance_epsilon"):
+         eps = self.variance_epsilon
+     elif hasattr(self, "epsilon"):
+         eps = self.epsilon
+     else:
+         raise TypeError(
+             "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
+         )
+
+     if use_cuda_kernel:
+         if residual is not None:
+             inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
+             return hidden_states, residual
+
+         if norm_output is None:
+             norm_output = torch.empty_like(hidden_states)
+         inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, eps)
+         return norm_output, hidden_states
+     else:
+         return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)
+
+
class NopadBaichuanAttention(nn.Module):
    def __init__(
        self,
@@ -39,9 +103,11 @@ def __init__(
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
-
-         # Used to adapt llama_base_attn_forward
-         self.num_key_value_heads = self.num_heads
+         self.alibi_slopes = None
+         self.use_alibi_attn = False
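+         # hidden_size 5120 corresponds to the 13B Baichuan models, which use ALiBi bias instead of rotary embeddings (the 7B models keep RoPE).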
+         if self.hidden_size == 5120:
+             self.use_alibi_attn = True
+             self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)

        qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
        self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
@@ -112,26 +178,124 @@ def forward(
            high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
        """

- return NopadLlamaAttention .forward (
116
- self ,
117
- hidden_states = hidden_states ,
118
- block_tables = block_tables ,
119
- k_cache = k_cache ,
120
- v_cache = v_cache ,
121
- sequence_lengths = sequence_lengths ,
122
- cos_sin = cos_sin ,
123
- fd_inter_tensor = fd_inter_tensor ,
124
- is_prompts = is_prompts ,
125
- is_verifier = is_verifier ,
126
- tokens_to_verify = tokens_to_verify ,
127
- kv_seq_len = kv_seq_len ,
128
- output_tensor = output_tensor ,
129
- sm_scale = sm_scale ,
130
- use_cuda_kernel = use_cuda_kernel ,
131
- cu_seqlens = cu_seqlens ,
132
- high_precision = high_precision ,
181
+ token_nums = hidden_states .size (0 )
182
+ # fused qkv
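+         # qkv_weight stacks the q/k/v projection matrices as (3, hidden_size, hidden_size), so broadcasting
+         # hidden_states to (3, token_nums, hidden_size) lets a single bmm produce Q, K and V together.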
+         hidden_states = hidden_states.expand(3, -1, -1)
+         query_states, key_states, value_states = (
+             torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
        )

+         block_size = k_cache.size(-2)
+
+         if is_prompts:
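+             # Prefill (prompt) stage: attention over the full prompt while the paged KV cache is being filled.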
+             if (
+                 not is_verifier
+                 and use_cuda_kernel
+                 and query_states.dtype != torch.float32
+                 and use_flash_attn2
+                 and not self.use_alibi_attn
+             ):
+                 # flash attn 2 currently only supports FP16/BF16.
+                 inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                 inference_ops.context_kv_cache_memcpy(
+                     key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
+                 )
+
+                 attn_output = flash_attn_varlen_func(
+                     query_states,
+                     key_states,
+                     value_states,
+                     cu_seqlens_q=cu_seqlens,
+                     cu_seqlens_k=cu_seqlens,
+                     max_seqlen_q=kv_seq_len,
+                     max_seqlen_k=kv_seq_len,
+                     dropout_p=0.0,
+                     softmax_scale=sm_scale,
+                     causal=True,
+                 )
+                 attn_output = attn_output.view(token_nums, -1)
+             else:
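+                 # Triton context-attention fallback: used for the ALiBi (13B) path, fp32 inputs, or when flash-attn 2 / the CUDA kernels are unavailable.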
+                 if not self.use_alibi_attn:
+                     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+                 attn_output = context_attention_unpadded(
+                     q=query_states,
+                     k=key_states,
+                     v=value_states,
+                     k_cache=k_cache,
+                     v_cache=v_cache,
+                     context_lengths=sequence_lengths,
+                     block_tables=block_tables,
+                     block_size=block_size,
+                     output=output_tensor,
+                     alibi_slopes=self.alibi_slopes,
+                     max_seq_len=kv_seq_len,
+                     sm_scale=sm_scale,
+                 )
+         else:
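+             # Decode stage: append the new K/V to the paged cache (CUDA or Triton kernels), then attend with flash decoding.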
+             q_len = tokens_to_verify + 1 if is_verifier else 1
+
+             if use_cuda_kernel:
+                 if not self.use_alibi_attn:
+                     inference_ops.rotary_embedding_and_cache_copy(
+                         query_states,
+                         key_states,
+                         value_states,
+                         cos_sin[0],
+                         cos_sin[1],
+                         k_cache,
+                         v_cache,
+                         sequence_lengths,
+                         block_tables,
+                         high_precision,
+                     )
+                 else:
+                     inference_ops.decode_kv_cache_memcpy(
+                         key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
+                     )
+             else:
+                 if not is_verifier and not self.use_alibi_attn:
+                     decoding_fused_rotary_embedding(
+                         query_states,
+                         key_states,
+                         value_states,
+                         cos_sin[0],
+                         cos_sin[1],
+                         k_cache,
+                         v_cache,
+                         block_tables,
+                         sequence_lengths,
+                     )
+                 else:
+                     if not self.use_alibi_attn:
+                         rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+                     copy_k_to_blocked_cache(
+                         key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
+                     )
+                     copy_k_to_blocked_cache(
+                         value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
+                     )
+
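+             # Attention over the paged KV cache via the Triton flash-decoding kernel; alibi_slopes is None on the RoPE (7B) path.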
+             attn_output = flash_decoding_attention(
+                 q=query_states,
+                 k_cache=k_cache,
+                 v_cache=v_cache,
+                 kv_seq_len=sequence_lengths,
+                 block_tables=block_tables,
+                 block_size=block_size,
+                 max_seq_len_in_batch=kv_seq_len,
+                 output=output_tensor,
+                 mid_output=fd_inter_tensor.mid_output,
+                 mid_output_lse=fd_inter_tensor.mid_output_lse,
+                 alibi_slopes=self.alibi_slopes,
+                 sm_scale=sm_scale,
+                 q_len=q_len,
+             )
+
+         attn_output = attn_output.view(-1, self.hidden_size)
+         attn_output = torch.mm(attn_output, self.o_proj_weight)
+
+         return attn_output
+

# NOTE This will cause difference as out length increases.
class NopadBaichuanMLP(nn.Module):