)
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.shardformer.shard import ShardConfig
- from colossalai.kernel.triton import flash_decoding_attention_with_alibi
+ from colossalai.kernel.triton import flash_decoding_attention, context_attention_unpadded
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.jit.bias_gelu import GeLUFunction
from colossalai.kernel.jit.bias_dropout_add import bias_dropout_add_fused_inference


import torch
import torch.nn.functional as F
+ import torch.nn as nn
from typing import List, Optional, Tuple
import math

@@ -61,26 +62,9 @@ def _get_alibi_tensor(n_heads: int, mask: torch.Tensor):
    return distance[:, :, None] * slopes[None, None, :]


- # def _fill_with_neg_inf(t):
- #     return t.float().fill_(float("-inf")).type_as(t)
-
- # # (Register buffer within BloomModel), only use for inference
- # def _get_alibi_tensor(max_pos: int, n_heads: int):
- #     slopes = torch.Tensor(_get_alibi_slopes(n_heads))
- #     alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0) \
- #         .expand(n_heads, -1, -1) \
- #         .view(n_heads, 1, max_pos)
-
- #     alibi_mask = torch.triu(
- #         _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
- #     )
- #     return alibi_mask.unsqueeze(0) + alibi
-
-
- # TODO
def bloom_model_forward(
    self: BloomModel,
-     input_tokens_ids: torch.Tensor,
+     input_tokens_ids: torch.Tensor,  # no padding
    output_tensor: torch.Tensor,
    inputmetadata: InputMetaData,
    k_caches: List[torch.Tensor] = None,
@@ -89,10 +73,10 @@ def bloom_model_forward(
    high_precision: bool = False,
) -> torch.Tensor:

-     def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = False):
-         if is_prompts:
-             is_prompts = False
-             self.register_buffer("future_mask", _get_alibi_tensor())
+     # def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = False):
+     #     if is_prompts:
+     #         is_prompts = False
+     #         self.register_buffer("future_mask", _get_alibi_tensor())

    is_prompts = inputmetadata.is_prompts
    block_tables = inputmetadata.block_tables
@@ -120,7 +104,7 @@ def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = False):
    # self.max_cache_pos = seq_length_with_past
    # self.register_buffer("future_mask", _get_alibi_tensor(self.n_head, self.max_cache_pos).to(hidden_states), persistent=False)

-     alibi = _get_alibi_slopes(self.n_head)
+     # alibi = _get_alibi_slopes(self.num_heads)
    # alibi_mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]

    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
@@ -129,7 +113,6 @@ def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = False):
    for layer_id, layer in enumerate(self.h):
        hidden_states = layer(
            hidden_states,
-             alibi=alibi,
            block_tables=block_tables,
            k_cache=k_caches[layer_id],
            v_cache=v_caches[layer_id],
@@ -138,8 +121,6 @@ def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = False):
            fd_inter_tensor=inputmetadata.fd_inter_tensor,
            kv_seq_len=kv_seq_len,
            output_tensor=output_tensor,
-             use_cuda_kernel=use_cuda_kernel,
-             high_precision=high_precision,
            norm_output=norm_output,
            sm_scale=sm_scale,
            use_cuda_kernel=use_cuda_kernel,
@@ -160,7 +141,7 @@ def bloom_causal_lm_forward(
) -> torch.Tensor:

    hidden_states = bloom_model_forward(
-         self.model,
+         self.transformer,
        input_tokens_ids=input_tokens_ids,
        output_tensor=output_tensor,
        inputmetadata=inputmetadata,
@@ -173,11 +154,9 @@
    return logits


- # TODO
def bloom_block_forward(
    self: BloomBlock,
    hidden_states: torch.Tensor,
-     alibi: torch.Tensor,
    block_tables: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
@@ -204,17 +183,14 @@ def bloom_block_forward(
    residual = hidden_states

    # Self attention
-     attn_output, _ = self.self_attention(
+     attn_output = self.self_attention(
        hidden_states=layernorm_output,
-         residual=residual,
-         alibi=alibi,
-         hidden_states=hidden_states,
        block_tables=block_tables,
        k_cache=k_cache,
        v_cache=v_cache,
        is_prompts=is_prompts,
-         is_verifier=is_verifier,
-         tokens_to_verify=tokens_to_verify,
+         # is_verifier=is_verifier,
+         # tokens_to_verify=tokens_to_verify,
        sequence_lengths=sequence_lengths,
        fd_inter_tensor=fd_inter_tensor,
        kv_seq_len=kv_seq_len,
@@ -233,46 +209,50 @@ def bloom_block_forward(
    else:
        residual = attn_output

-     # MLP
-     output = self.mlp(layernorm_output, residual)  # including residuals
+     print(f"[DEBUG] Show attn_output shape: {attn_output.shape}, \
+         show residual shape: {residual.shape} \
+         ")
+
+     # MLP (including residuals)
+     output = self.mlp(layernorm_output, residual)

    return output

-
- # TODO
- class ColossalInferBloomAttention(BloomAttention):
+
+ class NopadBloomAttention(nn.Module):
    def __init__(
        self,
-         config: BloomConfig,
+         hidden_size: int,
+         n_heads: int,
        attn_qproj_w: torch.Tensor = None,
        attn_kproj_w: torch.Tensor = None,
        attn_vproj_w: torch.Tensor = None,
        attn_oproj_w: torch.Tensor = None,
    ):
-         super().__init__(config)
-         self.q_proj_weight = attn_qproj_w
-         self.k_proj_weight = attn_kproj_w
-         self.v_proj_weight = attn_vproj_w
-         self.o_proj_weight = attn_oproj_w
-
-         qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight]
-         self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
+         super().__init__()

-         # garbage collection
-         self.q_proj = None
-         self.k_proj = None
-         self.v_proj = None
+         self.hidden_size = hidden_size
+         self.num_heads = n_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.o_proj_w = attn_oproj_w
+
+         qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
+         self.qkv_weight = torch.stack(qkv_weight_list, dim=0)

    @staticmethod
-     def from_native_module(module: BloomAttention, *args, **kwargs) -> BloomAttention:
-         config = module.config
-         attn_qproj_w = module.q_proj.weight.transpose(0, 1)
-         attn_kproj_w = module.k_proj.weight.transpose(0, 1)
-         attn_vproj_w = module.v_proj.weight.transpose(0, 1)
-         attn_oproj_w = module.o_proj.weight.transpose(0, 1)
+     def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBloomAttention":
+         hidden_size = module.hidden_size
+         num_heads = module.num_heads
+         q_proj_w, k_proj_w, v_proj_w = module.query_key_value.weight.view((3, hidden_size, hidden_size))

-         attn_layer = ColossalInferBloomAttention(
-             config=config,
+         attn_qproj_w = q_proj_w.transpose(0, 1)
+         attn_kproj_w = k_proj_w.transpose(0, 1)
+         attn_vproj_w = v_proj_w.transpose(0, 1)
+         attn_oproj_w = module.dense.weight.transpose(0, 1)
+
+         attn_layer = NopadBloomAttention(
+             hidden_size=hidden_size,
+             n_heads=num_heads,
            attn_qproj_w=attn_qproj_w,
            attn_kproj_w=attn_kproj_w,
            attn_vproj_w=attn_vproj_w,
@@ -284,7 +264,6 @@ def from_native_module(module: BloomAttention, *args, **kwargs) -> BloomAttention:
    def forward(
        self,
        hidden_states: torch.Tensor,
-         alibi: torch.Tensor,
        block_tables: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
@@ -297,39 +276,38 @@ def forward(
        use_cuda_kernel: bool = True,
        cu_seqlens: torch.Tensor = None,
        high_precision: bool = False,
-     ):
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        token_nums = hidden_states.size(0)
-
        hidden_states = hidden_states.expand(3, -1, -1)
        query_states, key_states, value_states = (
            torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
        )

        block_size = k_cache.size(-2)

-         if is_prompts:  # Prefilling
-
-             # TODO context stage alibi flash_attn
-             pass
-
-         else:  # Decoding
-
-             # If alibi in this way, then next step is to softmax with matmul_result,
-             # so do I need consider how to utilize the matmul_result
-             matmul_result = alibi.baddbmm(
-                 batch1=query_states,
-                 batch2=key_states,
-                 beta=self.beta,
-                 alpha=self.inv_norm_factor,
+         if is_prompts:
+             # TODO(char-1ee) Integrate context stage flash attention with alibi encoding
+             attn_output = context_attention_unpadded(
+                 q=query_states,
+                 k=key_states,
+                 v=value_states,
+                 k_cache=k_cache,
+                 v_cache=v_cache,
+                 context_lengths=sequence_lengths,
+                 block_size=block_size,
+                 block_tables=block_tables,
+                 output=output_tensor,
+                 alibi_slopes=fd_inter_tensor.alibi_slopes,
+                 max_seq_len=kv_seq_len,
+                 sm_scale=sm_scale,
            )
-
-
-             attn_output = flash_decoding_attention_with_alibi(
+         else:
+             attn_output = flash_decoding_attention(
                q=query_states,
                k_cache=k_cache,
                v_cache=v_cache,
-                 alibi=alibi,
+                 alibi_slopes=fd_inter_tensor.alibi_slopes,
                kv_seq_len=sequence_lengths,
                block_tables=block_tables,
                block_size=block_size,
@@ -341,23 +319,30 @@ def forward(
            )

        attn_output = attn_output.view(-1, self.hidden_size)
-         attn_output = torch.mm(attn_output, self.o_proj_weight)
-
+         attn_output = torch.mm(attn_output, self.o_proj_w)
        return attn_output


- class ColossalInferBloomMLP(BloomMLP):
-     def __init__(self, config: BloomConfig):
-         super().__init__(config)
+ class NopadBloomMLP(nn.Module):
+     def __init__(self, hidden_size: int = 64, hidden_dropout: float = 0.0):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.hidden_dropout = hidden_dropout
+         self.dense_h_to_4h = nn.Linear(hidden_size, hidden_size * 4)
        self.gelu_impl = GeLUFunction.apply
+         self.dense_4h_to_h = nn.Linear(hidden_size * 4, hidden_size)
+
+         self.dense_h_to_4h = self.dense_h_to_4h.half()
+         self.dense_4h_to_h = self.dense_4h_to_h.half()

    @staticmethod
-     def from_native_method(module: BloomMLP, *args, **kwargs) -> BloomMLP:
-         config = module.config
-         mlp_layer = ColossalInferBloomMLP(config=config)
+     def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBloomMLP":
+         hidden_size = 64  # TODO: hyperparameters
+         mlp_layer = NopadBloomMLP(hidden_size=hidden_size, hidden_dropout=module.hidden_dropout)
        return mlp_layer

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+         print(f"[DEBUG] Print shape of hidden_states: {hidden_states.shape}, and dtype is {hidden_states.dtype}")
        hidden_states = self.dense_h_to_4h(hidden_states)
        bias = torch.zeros_like(hidden_states)
        hidden_states = self.gelu_impl(hidden_states, bias)
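
A note on the fused QKV projection introduced in NopadBloomAttention above: the transposed Q/K/V weights are stacked into a single (3, hidden_size, hidden_size) tensor, and the activations are broadcast so one torch.bmm produces all three projections at once. The snippet below is a minimal stand-alone sketch of that trick with made-up toy shapes; it uses only plain PyTorch and does not touch any ColossalAI API.

import torch

# Toy sizes, chosen only for illustration.
hidden_size, num_heads, token_nums = 64, 4, 5
head_dim = hidden_size // num_heads

# Three (hidden_size, hidden_size) projection weights stacked along dim 0,
# mirroring self.qkv_weight in NopadBloomAttention.__init__.
qkv_weight = torch.stack([torch.randn(hidden_size, hidden_size) for _ in range(3)], dim=0)

hidden_states = torch.randn(token_nums, hidden_size)
# expand(3, -1, -1) broadcasts the activations to (3, token_nums, hidden_size)
# without copying, so one batched matmul yields Q, K and V together.
q, k, v = (
    torch.bmm(hidden_states.expand(3, -1, -1), qkv_weight)
    .view(3, token_nums, num_heads, head_dim)
    .unbind(0)
)
assert q.shape == (token_nums, num_heads, head_dim)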
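For context, the _get_alibi_slopes helper referenced by the now-commented-out alibi = _get_alibi_slopes(self.n_head) line is not shown in this diff. The sketch below follows the standard ALiBi/BLOOM slope recipe (a geometric sequence per head, with an interleaved second sequence when the head count is not a power of two); it is an assumption about that helper's behaviour, not a copy of the repository's implementation.

import math
import torch

def get_alibi_slopes_sketch(num_heads: int) -> torch.Tensor:
    # Geometric sequence starting at 2^(-8/n) for the closest power-of-two head count.
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_pow2) - 3)))
    slopes = torch.pow(base, torch.arange(1, closest_pow2 + 1, dtype=torch.float32))
    if closest_pow2 != num_heads:
        # Remaining heads take every other slope from the next power-of-two sequence.
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_pow2) - 3)))
        num_extra = min(closest_pow2, num_heads - closest_pow2)
        extra = torch.pow(extra_base, torch.arange(1, 2 * num_extra + 1, 2, dtype=torch.float32))
        slopes = torch.cat([slopes, extra], dim=0)
    return slopes  # shape: (num_heads,)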