
Commit 59ba43b

Rebase upstream commits and refactor
Signed-off-by: char-1ee <[email protected]>
1 parent 67d67fb commit 59ba43b

File tree: 11 files changed, +218 −400 lines


colossalai/inference/core/engine.py

Lines changed: 3 additions & 12 deletions
@@ -13,8 +13,8 @@
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from transformers.models.bloom.modeling_bloom import BloomForCausalLM
+from transformers.models.llama.modeling_llama import LlamaForCausalLM

 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh

@@ -43,7 +43,6 @@
     "BloomForCausalLM": BloomForCausalLM,
 }

-_alibi_models = ["bloom", "baichuan"]

 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

@@ -83,7 +82,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.tokenizer.pad_token = self.tokenizer.eos_token

-        self.request_handler = RequestHandler(self.inference_config, self.model_config, alibi_attn=self.alibi_attn)
+        self.request_handler = RequestHandler(self.inference_config, self.model_config)
         self.k_cache, self.v_cache = self.request_handler.get_kvcache()
         # DISCUSS maybe move this into batch info?

@@ -164,14 +163,6 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
         tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

-        self.alibi_attn = False
-        if self.model_config.model_type in _alibi_models:
-            # Used for bloom, baichuan 13b and baichuan2 13b.
-            self.alibi_attn = True
-            # Hardcode used to distinguish between baichuan 7b and baichuan 13b.(There might be a better way to handle this.)
-            if self.model_config.model_type == "baichuan" and self.model_config.hidden_size == 4096:
-                self.alibi_attn = False
-
         self.model = self._shardformer(
             model,
             model_policy,

@@ -747,4 +738,4 @@ def step(self) -> List[str]:

         finished_sequences = self.request_handler.update()

-        return finished_sequences
+        return finished_sequences

colossalai/inference/kv_cache/kvcache_manager.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 def get_model_config_attr(config: PretrainedConfig, attr_name: str, alter_attr: Any = None):
     if hasattr(config, attr_name):
         return getattr(config, attr_name)
-    if alter_attr is not None:  # TODO, rebase caidi changes
+    if alter_attr is not None:
         return alter_attr
     elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
         return getattr(config, config.attribute_map[attr_name])
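For reference, a minimal usage sketch of this helper; the `LlamaConfig` example and the import path below are assumptions for illustration, not part of the commit:

```python
# Hypothetical usage of get_model_config_attr (not from the commit).
# Resolution order: direct attribute -> caller-supplied alter_attr ->
# the config's attribute_map aliases.
from transformers import LlamaConfig

from colossalai.inference.kv_cache.kvcache_manager import get_model_config_attr  # assumed import path

config = LlamaConfig(hidden_size=4096, num_attention_heads=32)

# Direct hit: the attribute exists on the config object.
num_heads = get_model_config_attr(config, "num_attention_heads")

# Fallback: if an attribute were missing on a given config, the caller could
# supply an alternative value through alter_attr instead of raising.
kv_heads = get_model_config_attr(config, "num_key_value_heads", alter_attr=num_heads)
```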

colossalai/inference/modeling/models/nopadding_baichuan.py

Lines changed: 1 addition & 17 deletions
@@ -8,7 +8,7 @@
 from torch.distributed import ProcessGroup

 from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
+from colossalai.inference.utils import get_alibi_slopes
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
     context_attention_unpadded,

@@ -47,22 +47,6 @@
 logger = get_dist_logger(__name__)


-# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
-def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
-    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
-    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
-    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
-    slopes = torch.pow(base, powers)
-    if closest_power_of_2 != num_heads:
-        extra_base = torch.tensor(
-            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
-        )
-        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
-        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
-        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
-    return slopes
-
-
 def baichuan_rmsnorm_forward(
     self,
     hidden_states: torch.Tensor,

colossalai/inference/modeling/models/nopadding_bloom.py

Lines changed: 74 additions & 24 deletions
@@ -6,24 +6,27 @@

 from colossalai.inference.config import InputMetaData
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.utils import get_alibi_slopes
 from colossalai.kernel.jit.bias_dropout_add import bias_dropout_add_fused_inference
 from colossalai.kernel.jit.bias_gelu import GeLUFunction
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
-from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention
+from colossalai.kernel.triton import context_attention_unpadded, copy_k_to_blocked_cache, flash_decoding_attention
 from colossalai.logging import get_dist_logger

 logger = get_dist_logger(__name__)

-inference_ops = InferenceOpsLoader.load()
-
 try:
-    pass
+    from flash_attn import flash_attn_varlen_func

     use_flash_attn2 = True
 except ImportError:
     use_flash_attn2 = False
     logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")

+inference_ops = InferenceOpsLoader().load()
+
+logger = get_dist_logger(__name__)
+

 def bloom_causal_lm_forward(
     self: BloomForCausalLM,

@@ -107,6 +110,7 @@ def bloom_model_forward(
         hidden_states = layer(
             hidden_states,
             block_tables=block_tables,
+            is_prompts=inputmetadata.is_prompts,
             k_cache=k_caches[layer_id],
             v_cache=v_caches[layer_id],
             sequence_lengths=sequence_lengths,

@@ -144,7 +148,7 @@ def bloom_block_forward(
     use_cuda_kernel: bool = True,
     cu_seqlens: torch.Tensor = None,
     high_precision: bool = False,
-) -> torch.Tensor:
+) -> torch.FloatTensor:
     """
     Replacement of forward function in the BloomBlock module.

@@ -234,6 +238,7 @@ def __init__(

         self.hidden_size = hidden_size
         self.num_heads = n_heads
+        self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
         self.head_dim = self.hidden_size // self.num_heads
         self.o_proj_w = attn_oproj_w

@@ -289,7 +294,7 @@ def forward(
         high_precision: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """
-        Forward function of the NopadBloomAttention.
+        Forward function of the NopadBloomAttention. Current attention does not support speculative decoding.

         Args:
             hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].

@@ -318,28 +323,73 @@ def forward(

         block_size = k_cache.size(-2)

-        # TODO: flash attention
-        if is_prompts:  # Prefilling phase
-            attn_output = context_attention_unpadded(
-                q=query_states,
-                k=key_states,
-                v=value_states,
-                k_cache=k_cache,
-                v_cache=v_cache,
-                context_lengths=sequence_lengths,
-                block_size=block_size,
-                block_tables=block_tables,
-                output=output_tensor,
-                alibi_slopes=fd_inter_tensor.alibi_slopes,
-                max_seq_len=kv_seq_len,
-                sm_scale=sm_scale,
-            )
-        else:  # Decoding phase
+        if is_prompts:  # Context stage (prefilling phase)
+            if (
+                use_cuda_kernel
+                and query_states.dtype != torch.float32
+                and use_flash_attn2  # flash attn 2 currently only supports FP16/BF16
+            ):
+                # Copy the GPU memory of kvcache during context stage
+                inference_ops.context_kv_cache_memcpy(
+                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
+                )
+
+                attn_output = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens,
+                    cu_seqlens_k=cu_seqlens,
+                    max_seqlen_q=kv_seq_len,
+                    max_seqlen_k=kv_seq_len,
+                    dropout_p=0.0,
+                    softmax_scale=sm_scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                )
+                attn_output = attn_output.view(token_nums, -1)
+
+            else:
+                attn_output = context_attention_unpadded(
+                    q=query_states,
+                    k=key_states,
+                    v=value_states,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    context_lengths=sequence_lengths,
+                    block_size=block_size,
+                    block_tables=block_tables,
+                    output=output_tensor,
+                    alibi_slopes=self.alibi_slopes,
+                    max_seq_len=kv_seq_len,
+                    sm_scale=sm_scale,
+                )
+
+        else:  # Decode stage
+            if use_cuda_kernel:
+                # Copy the GPU memory of kvcache during decode stage
+                inference_ops.decode_kv_cache_memcpy(
+                    key_states, value_states, k_cache, v_cache, sequence_lengths, block_size, block_tables
+                )
+            else:
+                copy_k_to_blocked_cache(
+                    key_states,
+                    k_cache,
+                    kv_lengths=sequence_lengths,
+                    block_tables=block_tables,
+                )
+                copy_k_to_blocked_cache(
+                    value_states,
+                    v_cache,
+                    kv_lengths=sequence_lengths,
+                    block_tables=block_tables,
+                )
+
             attn_output = flash_decoding_attention(
                 q=query_states,
                 k_cache=k_cache,
                 v_cache=v_cache,
-                alibi_slopes=fd_inter_tensor.alibi_slopes,
+                alibi_slopes=self.alibi_slopes,
                 kv_seq_len=sequence_lengths,
                 block_tables=block_tables,
                 block_size=block_size,
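The refactored forward above chooses between three attention paths depending on the stage and the available backends. A condensed sketch of just that dispatch logic follows; the branch conditions mirror the hunk above, while the helper name and the string return values are illustrative only, not part of the commit:

```python
import torch


def bloom_attention_path(is_prompts: bool, use_cuda_kernel: bool, use_flash_attn2: bool, dtype: torch.dtype) -> str:
    """Illustrative helper (not in the commit) naming the kernel path the new forward takes."""
    if is_prompts:
        # Prefill / context stage.
        if use_cuda_kernel and dtype != torch.float32 and use_flash_attn2:
            # CUDA memcpy of K/V into the paged cache, then FlashAttention-2
            # over the unpadded varlen batch (FP16/BF16 only), with ALiBi slopes.
            return "context_kv_cache_memcpy + flash_attn_varlen_func"
        # Fallback: Triton context attention, which also fills the KV cache.
        return "context_attention_unpadded"
    # Decode stage: write the new K/V token into the blocked cache (CUDA kernel
    # if available, otherwise two copy_k_to_blocked_cache calls), then run
    # paged flash decoding with self.alibi_slopes.
    return "decode_kv_cache_memcpy (or copy_k_to_blocked_cache) + flash_decoding_attention"


print(bloom_attention_path(is_prompts=True, use_cuda_kernel=True, use_flash_attn2=True, dtype=torch.float16))
```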

colossalai/inference/utils.py

Lines changed: 27 additions & 1 deletion
@@ -1,6 +1,7 @@
 """
-Utils for model inference
+Utilities for model inference
 """
+import math
 import os
 import re
 from pathlib import Path

@@ -55,6 +56,31 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
     self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()


+def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
+    """
+    Calculate the slopes for the Alibi positional encoding. The calculation is adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
+
+    Args:
+        num_heads (int): The number of heads.
+        device (torch.device): The device to perform the calculations on.
+
+    Returns:
+        torch.Tensor: The calculated slopes tensor of (nheads,) or (batch_size, nheads).
+    """
+    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
+    slopes = torch.pow(base, powers)
+    if closest_power_of_2 != num_heads:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
+        )
+        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
 def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
     """
     Check whether the checkpoint has an index file.
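As a quick check on the relocated helper: for a power-of-two head count, the slopes form a geometric sequence from 2**(-8/num_heads) down to 2**-8, matching the referenced Hugging Face Bloom code. A minimal usage sketch (the expected values are computed from the formula above, not taken from the commit):

```python
# Minimal usage sketch of the relocated helper (illustration, not from the commit).
import torch

from colossalai.inference.utils import get_alibi_slopes

# Power-of-two head count: slopes are 2**(-8/8), 2**(-16/8), ..., 2**-8.
slopes = get_alibi_slopes(num_heads=8, device=torch.device("cpu"))
print(slopes)
# Expected: tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

# Non-power-of-two head counts append extra interleaved slopes,
# so the result still has one slope per head.
print(get_alibi_slopes(num_heads=12, device=torch.device("cpu")).shape)  # torch.Size([12])
```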
