
Commit 3c91e3f

[Inference] Adapt to baichuan2 13B (#5614)

* Adapt to baichuan2 13B
* Change BAICHUAN_MODEL_NAME_OR_PATH
* Fix test_decoding_attn.py
* Apply modifications based on review comments
* Move the attention-mask processing into the flash-decoding test
* Move get_alibi_slopes into the baichuan modeling module
* Fix bugs in test_baichuan.py
1 parent f342a93 commit 3c91e3f

File tree

10 files changed, +786 −134 lines changed

colossalai/inference/flash_decoding_utils.py

Lines changed: 1 addition & 0 deletions
@@ -60,4 +60,5 @@ def initialize(
         self._mid_output_lse = torch.empty(
             size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
         )
+
         self._tensors_initialized = True

colossalai/inference/kv_cache/kvcache_manager.py

Lines changed: 8 additions & 1 deletion
@@ -64,8 +64,15 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
-        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
+
+        if hasattr(config, "num_key_value_heads"):
+            self.kv_head_num = getattr(config, "num_key_value_heads")
+        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
+            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        else:
+            self.kv_head_num = self.head_num
+
         assert (
             self.kv_head_num % self.tp_size == 0
         ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"

colossalai/inference/modeling/models/nopadding_baichuan.py

Lines changed: 186 additions & 22 deletions
@@ -1,19 +1,83 @@
 # This code is adapted from huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
+import math
 from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.kernel.triton import (
+    context_attention_unpadded,
+    copy_k_to_blocked_cache,
+    decoding_fused_rotary_embedding,
+    flash_decoding_attention,
+    rms_layernorm,
+    rotary_embedding,
+)
 from colossalai.logging import get_dist_logger
 
+logger = get_dist_logger(__name__)
+
+try:
+    from flash_attn import flash_attn_varlen_func
+
+    use_flash_attn2 = True
+except ImportError:
+    use_flash_attn2 = False
+    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
+
 inference_ops = InferenceOpsLoader().load()
 
 logger = get_dist_logger(__name__)
 
 
+# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
+def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
+    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
+    slopes = torch.pow(base, powers)
+    if closest_power_of_2 != num_heads:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
+        )
+        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
+def baichuan_rmsnorm_forward(
+    self,
+    hidden_states: torch.Tensor,
+    norm_output: torch.Tensor,
+    residual: torch.Tensor = None,
+    use_cuda_kernel: bool = True,
+):
+    # Used to address the issue of inconsistent epsilon variable names in baichuan2 7b and 13b.
+    if hasattr(self, "variance_epsilon"):
+        eps = self.variance_epsilon
+    elif hasattr(self, "epsilon"):
+        eps = self.epsilon
+    else:
+        raise TypeError(
+            "Currently, the variable name for the epsilon of baichuan 7B/13B should be 'variance_epsilon' or 'epsilon'."
+        )
+
+    if use_cuda_kernel:
+        if residual is not None:
+            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
+            return hidden_states, residual
+
+        if norm_output is None:
+            norm_output = torch.empty_like(hidden_states)
+        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, eps)
+        return norm_output, hidden_states
+    else:
+        return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)
+
+
 class NopadBaichuanAttention(nn.Module):
     def __init__(
         self,
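
Note: to make the slope construction in get_alibi_slopes above concrete, here is a worked illustration (not part of the commit) of the values it produces. For a power-of-two head count the slopes are a plain geometric sequence; for Baichuan2-13B's 40 heads, the first 32 slopes come from base 2**-0.25 and 8 interpolated slopes from base 2**-0.125 are appended:

# Worked illustration of the ALiBi slope schedule (mirrors get_alibi_slopes).
import torch

# num_heads = 4 (a power of two): base = 2 ** -(2 ** -(log2(4) - 3)) = 2 ** -2,
# so the slopes are 2**-2, 2**-4, 2**-6, 2**-8.
slopes_4 = torch.pow(torch.tensor(2.0 ** -2), torch.arange(1, 5))
print(slopes_4)  # tensor([0.2500, 0.0625, 0.0156, 0.0039])

# num_heads = 40 (Baichuan2-13B): the closest power of two is 32, so the first 32 slopes
# use base 2**-0.25, and the remaining 8 are interpolated from base 2**-0.125 at the
# odd powers 1, 3, ..., 15.
main = torch.pow(torch.tensor(2.0 ** -0.25), torch.arange(1, 33))
extra = torch.pow(torch.tensor(2.0 ** -0.125), torch.arange(1, 17, 2))
slopes_40 = torch.cat([main, extra])
assert slopes_40.shape == (40,)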
@@ -39,9 +103,11 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
-
-        # Used to adapt llama_base_attn_forward
-        self.num_key_value_heads = self.num_heads
+        self.alibi_slopes = None
+        self.use_alibi_attn = False
+        if self.hidden_size == 5120:
+            self.use_alibi_attn = True
+            self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
 
         qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
         self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
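
Note: the hidden_size == 5120 check is how this commit distinguishes Baichuan2-13B (ALiBi positional bias, no rotary embedding) from the 7B variant (RoPE). Inside the attention kernels the per-head slopes become a linear bias on the attention scores. The snippet below is a reference-only sketch of that math in plain PyTorch, not the Triton/CUDA kernels the commit actually calls:

# Reference-only sketch of ALiBi-biased causal attention (illustration, not the commit's kernels).
import torch

def alibi_reference_attention(q, k, v, slopes):
    # q, k, v: (num_heads, seq_len, head_dim); slopes: (num_heads,)
    num_heads, seq_len, head_dim = q.shape
    scores = q @ k.transpose(-1, -2) / head_dim ** 0.5            # (heads, seq, seq)
    # ALiBi adds slope * (key_pos - query_pos) to each score; under the causal mask
    # only key_pos <= query_pos survives, so the bias is <= 0 and grows with distance.
    positions = torch.arange(seq_len)
    bias = slopes.view(-1, 1, 1) * (positions.view(1, 1, -1) - positions.view(1, -1, 1))
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = (scores + bias).masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v                      # (heads, seq, head_dim)

q = k = v = torch.randn(4, 6, 8)
out = alibi_reference_attention(q, k, v, torch.tensor([0.25, 0.0625, 0.015625, 0.00390625]))
assert out.shape == (4, 6, 8)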
@@ -112,26 +178,124 @@ def forward(
             high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
         """
 
-        return NopadLlamaAttention.forward(
-            self,
-            hidden_states=hidden_states,
-            block_tables=block_tables,
-            k_cache=k_cache,
-            v_cache=v_cache,
-            sequence_lengths=sequence_lengths,
-            cos_sin=cos_sin,
-            fd_inter_tensor=fd_inter_tensor,
-            is_prompts=is_prompts,
-            is_verifier=is_verifier,
-            tokens_to_verify=tokens_to_verify,
-            kv_seq_len=kv_seq_len,
-            output_tensor=output_tensor,
-            sm_scale=sm_scale,
-            use_cuda_kernel=use_cuda_kernel,
-            cu_seqlens=cu_seqlens,
-            high_precision=high_precision,
+        token_nums = hidden_states.size(0)
+        # fused qkv
+        hidden_states = hidden_states.expand(3, -1, -1)
+        query_states, key_states, value_states = (
+            torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
         )
 
+        block_size = k_cache.size(-2)
+
+        if is_prompts:
+            if (
+                not is_verifier
+                and use_cuda_kernel
+                and query_states.dtype != torch.float32
+                and use_flash_attn2
+                and not self.use_alibi_attn
+            ):
+                # flash attn 2 currently only supports FP16/BF16.
+                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                inference_ops.context_kv_cache_memcpy(
+                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
+                )
+
+                attn_output = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens,
+                    cu_seqlens_k=cu_seqlens,
+                    max_seqlen_q=kv_seq_len,
+                    max_seqlen_k=kv_seq_len,
+                    dropout_p=0.0,
+                    softmax_scale=sm_scale,
+                    causal=True,
+                )
+                attn_output = attn_output.view(token_nums, -1)
+            else:
+                if not self.use_alibi_attn:
+                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+                attn_output = context_attention_unpadded(
+                    q=query_states,
+                    k=key_states,
+                    v=value_states,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    context_lengths=sequence_lengths,
+                    block_tables=block_tables,
+                    block_size=block_size,
+                    output=output_tensor,
+                    alibi_slopes=self.alibi_slopes,
+                    max_seq_len=kv_seq_len,
+                    sm_scale=sm_scale,
+                )
+        else:
+            q_len = tokens_to_verify + 1 if is_verifier else 1
+
+            if use_cuda_kernel:
+                if not self.use_alibi_attn:
+                    inference_ops.rotary_embedding_and_cache_copy(
+                        query_states,
+                        key_states,
+                        value_states,
+                        cos_sin[0],
+                        cos_sin[1],
+                        k_cache,
+                        v_cache,
+                        sequence_lengths,
+                        block_tables,
+                        high_precision,
+                    )
+                else:
+                    inference_ops.decode_kv_cache_memcpy(
+                        key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
+                    )
+            else:
+                if not is_verifier and not self.use_alibi_attn:
+                    decoding_fused_rotary_embedding(
+                        query_states,
+                        key_states,
+                        value_states,
+                        cos_sin[0],
+                        cos_sin[1],
+                        k_cache,
+                        v_cache,
+                        block_tables,
+                        sequence_lengths,
+                    )
+                else:
+                    if not self.use_alibi_attn:
+                        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+                    copy_k_to_blocked_cache(
+                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
+                    )
+                    copy_k_to_blocked_cache(
+                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
+                    )
+
+            attn_output = flash_decoding_attention(
+                q=query_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                kv_seq_len=sequence_lengths,
+                block_tables=block_tables,
+                block_size=block_size,
+                max_seq_len_in_batch=kv_seq_len,
+                output=output_tensor,
+                mid_output=fd_inter_tensor.mid_output,
+                mid_output_lse=fd_inter_tensor.mid_output_lse,
+                alibi_slopes=self.alibi_slopes,
+                sm_scale=sm_scale,
+                q_len=q_len,
+            )
+
+        attn_output = attn_output.view(-1, self.hidden_size)
+        attn_output = torch.mm(attn_output, self.o_proj_weight)
+
+        return attn_output
+
 
 # NOTE This will cause difference as out length increases.
 class NopadBaichuanMLP(nn.Module):
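
Note: the new forward pass above replaces the delegation to NopadLlamaAttention with an inline implementation. Its first step, the fused QKV projection, batches the three projections into a single torch.bmm over the stacked q/k/v weights. The shape-only sketch below is an illustration rather than the commit's code; it assumes each projection weight is already laid out as a (hidden_size, hidden_size) matrix applied on the right, and it uses toy sizes instead of the 13B model's 40 heads of dimension 128:

# Shape-only illustration of the fused QKV projection used in forward().
import torch

token_nums, num_heads, head_dim = 4, 4, 8   # toy sizes for the demo
hidden_size = num_heads * head_dim

# Assumed layout: each projection weight is (hidden_size, hidden_size), applied as x @ W.
q_w, k_w, v_w = (torch.randn(hidden_size, hidden_size) for _ in range(3))
qkv_weight = torch.stack([q_w, k_w, v_w], dim=0)           # (3, hidden, hidden)

hidden_states = torch.randn(token_nums, hidden_size)
expanded = hidden_states.expand(3, -1, -1)                  # (3, tokens, hidden), no data copy
q, k, v = torch.bmm(expanded, qkv_weight).view(3, token_nums, num_heads, head_dim).unbind(0)

assert q.shape == k.shape == v.shape == (token_nums, num_heads, head_dim)
assert torch.allclose(q, (hidden_states @ q_w).view(token_nums, num_heads, head_dim))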

colossalai/inference/modeling/policy/nopadding_baichuan.py

Lines changed: 26 additions & 21 deletions
@@ -1,12 +1,15 @@
 import torch.nn as nn
 from torch.nn import Parameter
 
-from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaichuanAttention, NopadBaichuanMLP
+from colossalai.inference.modeling.models.nopadding_baichuan import (
+    NopadBaichuanAttention,
+    NopadBaichuanMLP,
+    baichuan_rmsnorm_forward,
+)
 from colossalai.inference.modeling.models.nopadding_llama import (
     llama_causal_lm_forward,
     llama_decoder_layer_forward,
     llama_model_forward,
-    llama_rmsnorm_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
@@ -21,38 +24,40 @@ def module_policy(self):
         policy = super().module_policy()
 
         decoder_attribute_replacement = {
-            "lm_head.weight": Parameter(
-                nn.functional.normalize(self.model.lm_head.weight).transpose(0, 1), requires_grad=False
-            ),
+            "lm_head.weight": Parameter(nn.functional.normalize(self.model.lm_head.weight), requires_grad=False),
         }
         policy["BaichuanForCausalLM"] = ModulePolicyDescription(
             attribute_replacement=decoder_attribute_replacement,
         )
 
-        policy["DecoderLayer"] = ModulePolicyDescription(
-            sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="mlp",
-                    target_module=NopadBaichuanMLP,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="self_attn",
-                    target_module=NopadBaichuanAttention,
-                ),
-            ]
-        )
+        # used for replacing the Baichuan 7B/13B decoder layers
+        for layer_name in ["DecoderLayer", "BaichuanLayer"]:
+            policy[layer_name] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="mlp",
+                        target_module=NopadBaichuanMLP,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn",
+                        target_module=NopadBaichuanAttention,
+                    ),
+                ]
+            )
+
+            self.append_or_create_method_replacement(
+                description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=layer_name
+            )
 
         self.append_or_create_method_replacement(
             description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
         )
         self.append_or_create_method_replacement(
             description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
         )
+
         self.append_or_create_method_replacement(
-            description={"forward": llama_decoder_layer_forward}, policy=policy, target_key="DecoderLayer"
-        )
-        self.append_or_create_method_replacement(
-            description={"forward": llama_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
+            description={"forward": baichuan_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
         )
 
         return policy
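
Note: the lm_head.weight replacement above pre-normalizes the head weights once at policy-build time instead of per forward pass, and it no longer transposes them. The snippet below is a hedged illustration of my reading of that change, not code from the commit; it assumes Baichuan2's NormHead computes logits against row-normalized weights at inference, so baking the normalization into a plain Parameter gives the same logits:

# Illustration only (assumption: NormHead row-normalizes its weight before the matmul).
import torch
import torch.nn.functional as F

vocab_size, hidden_size, batch = 1000, 64, 2      # toy sizes for the demo
weight = torch.randn(vocab_size, hidden_size)
hidden_states = torch.randn(batch, hidden_size)

# What the policy stores: weight normalized once, wrapped as a frozen Parameter.
norm_weight = torch.nn.Parameter(F.normalize(weight), requires_grad=False)
logits_pre_normalized = F.linear(hidden_states, norm_weight)

# NormHead-style: normalize on every forward call.
logits_normhead_style = F.linear(hidden_states, F.normalize(weight))

assert torch.allclose(logits_pre_normalized, logits_normhead_style)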
