 
 from colossalai.inference.config import InputMetaData
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.utils import get_alibi_slopes
 from colossalai.kernel.jit.bias_dropout_add import bias_dropout_add_fused_inference
 from colossalai.kernel.jit.bias_gelu import GeLUFunction
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
-from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention
+from colossalai.kernel.triton import context_attention_unpadded, copy_k_to_blocked_cache, flash_decoding_attention
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
 
-inference_ops = InferenceOpsLoader.load()
-
 try:
-    pass
+    from flash_attn import flash_attn_varlen_func
 
     use_flash_attn2 = True
 except ImportError:
     use_flash_attn2 = False
     logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
 
+inference_ops = InferenceOpsLoader().load()
+
+logger = get_dist_logger(__name__)
+
 
 def bloom_causal_lm_forward(
     self: BloomForCausalLM,
@@ -107,6 +110,7 @@ def bloom_model_forward(
         hidden_states = layer(
             hidden_states,
             block_tables=block_tables,
+            is_prompts=inputmetadata.is_prompts,
             k_cache=k_caches[layer_id],
             v_cache=v_caches[layer_id],
             sequence_lengths=sequence_lengths,
@@ -144,7 +148,7 @@ def bloom_block_forward(
     use_cuda_kernel: bool = True,
     cu_seqlens: torch.Tensor = None,
     high_precision: bool = False,
-) -> torch.Tensor:
+) -> torch.FloatTensor:
     """
     Replacement of forward function in the BloomBlock module.
 
@@ -234,6 +238,7 @@ def __init__(
 
         self.hidden_size = hidden_size
         self.num_heads = n_heads
+        self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
         self.head_dim = self.hidden_size // self.num_heads
         self.o_proj_w = attn_oproj_w
 
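This hunk moves the ALiBi slopes from the per-batch fd_inter_tensor.alibi_slopes (see the attention hunk below) to a tensor computed once per attention module in __init__ via get_alibi_slopes. As a reference for what such a helper computes, here is a minimal sketch of the standard BLOOM/ALiBi slope recipe; the actual implementation in colossalai.inference.utils may differ in details such as dtype handling, and get_alibi_slopes_sketch is an illustrative name, not the library function.

    import math

    import torch


    def get_alibi_slopes_sketch(num_heads: int, device: torch.device) -> torch.Tensor:
        """Standard BLOOM-style ALiBi slopes: one geometric-sequence slope per head."""
        closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
        # For a power-of-two head count the ratio works out to 2 ** (-8 / num_heads).
        base = torch.tensor(
            2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device
        )
        powers = torch.arange(1, closest_power_of_2 + 1, dtype=torch.float32, device=device)
        slopes = torch.pow(base, powers)
        if closest_power_of_2 != num_heads:
            # Non-power-of-two head counts borrow every other slope from the
            # sequence that 2 * closest_power_of_2 heads would use.
            extra_base = torch.tensor(
                2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
            )
            num_extra = min(closest_power_of_2, num_heads - closest_power_of_2)
            extra_powers = torch.arange(1, 1 + 2 * num_extra, 2, dtype=torch.float32, device=device)
            slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
        return slopes

For 32 attention heads this yields slopes 2**-0.25, 2**-0.5, ..., 2**-8, the bias schedule BLOOM's attention uses.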
@@ -289,7 +294,7 @@ def forward(
         high_precision: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """
-        Forward function of the NopadBloomAttention.
+        Forward function of the NopadBloomAttention. Current attention does not support speculative decoding.
 
         Args:
             hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
@@ -318,28 +323,73 @@ def forward(
 
         block_size = k_cache.size(-2)
 
-        # TODO: flash attention
-        if is_prompts:  # Prefilling phase
-            attn_output = context_attention_unpadded(
-                q=query_states,
-                k=key_states,
-                v=value_states,
-                k_cache=k_cache,
-                v_cache=v_cache,
-                context_lengths=sequence_lengths,
-                block_size=block_size,
-                block_tables=block_tables,
-                output=output_tensor,
-                alibi_slopes=fd_inter_tensor.alibi_slopes,
-                max_seq_len=kv_seq_len,
-                sm_scale=sm_scale,
-            )
-        else:  # Decoding phase
+        if is_prompts:  # Context stage (prefilling phase)
+            if (
+                use_cuda_kernel
+                and query_states.dtype != torch.float32
+                and use_flash_attn2  # flash attn 2 currently only supports FP16/BF16
+            ):
+                # Copy the GPU memory of kvcache during context stage
+                inference_ops.context_kv_cache_memcpy(
+                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
+                )
+
+                attn_output = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens,
+                    cu_seqlens_k=cu_seqlens,
+                    max_seqlen_q=kv_seq_len,
+                    max_seqlen_k=kv_seq_len,
+                    dropout_p=0.0,
+                    softmax_scale=sm_scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                )
+                attn_output = attn_output.view(token_nums, -1)
+
+            else:
+                attn_output = context_attention_unpadded(
+                    q=query_states,
+                    k=key_states,
+                    v=value_states,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    context_lengths=sequence_lengths,
+                    block_size=block_size,
+                    block_tables=block_tables,
+                    output=output_tensor,
+                    alibi_slopes=self.alibi_slopes,
+                    max_seq_len=kv_seq_len,
+                    sm_scale=sm_scale,
+                )
+
+        else:  # Decode stage
+            if use_cuda_kernel:
+                # Copy the GPU memory of kvcache during decode stage
+                inference_ops.decode_kv_cache_memcpy(
+                    key_states, value_states, k_cache, v_cache, sequence_lengths, block_size, block_tables
+                )
+            else:
+                copy_k_to_blocked_cache(
+                    key_states,
+                    k_cache,
+                    kv_lengths=sequence_lengths,
+                    block_tables=block_tables,
+                )
+                copy_k_to_blocked_cache(
+                    value_states,
+                    v_cache,
+                    kv_lengths=sequence_lengths,
+                    block_tables=block_tables,
+                )
+
             attn_output = flash_decoding_attention(
                 q=query_states,
                 k_cache=k_cache,
                 v_cache=v_cache,
-                alibi_slopes=fd_inter_tensor.alibi_slopes,
+                alibi_slopes=self.alibi_slopes,
                 kv_seq_len=sequence_lengths,
                 block_tables=block_tables,
                 block_size=block_size,