
Commit 5f00002

[Inference] Adapt Baichuan2-13B TP (#5659)
* adapt to baichuan2 13B
* add baichuan2 13B TP
* update baichuan tp logic
* rm unused code
* Fix TP logic
* fix alibi slopes tp logic
* rm nn.Module
* Polished the code.
* change BAICHUAN_MODEL_NAME_OR_PATH
* Modified the logic for loading Baichuan weights.
* fix typos
1 parent 808ee6e commit 5f00002

7 files changed: +280 additions, -98 deletions


colossalai/inference/config.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 _DEFAULT_PROMPT_TEMPLATES = {
     "llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
-    "baichuan": "<reserved_106>{input_text}<reserved_107>",
+    "baichuan": " <reserved_106> {input_text} <reserved_107> ",
     "vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
 }
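For reference, the updated entry is an ordinary Python format string, so the visible effect of this change is only the whitespace placed around the two reserved tokens. A minimal illustration (how the engine substitutes the prompt is assumed here, not shown in this diff):

# Plain str.format, mirroring how a prompt would be slotted into the new template.
template = " <reserved_106> {input_text} <reserved_107> "
print(template.format(input_text="Who are you?"))
# -> ' <reserved_106> Who are you? <reserved_107> '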

colossalai/inference/core/engine.py

Lines changed: 14 additions & 2 deletions
@@ -112,11 +112,23 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
             model_policy (Policy): the policy to replace the model
         """

+        casuallm = None
         if isinstance(model_or_path, str):
             try:
                 hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                 arch = getattr(hf_config, "architectures")[0]
-                model = _supported_models[arch](hf_config)
+                if arch in _supported_models.keys():
+                    casuallm = _supported_models[arch](hf_config)
+                    if isinstance(casuallm, AutoModelForCausalLM):
+                        # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
+                        model = (
+                            AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
+                        )
+                    else:
+                        model = _supported_models[arch](hf_config)
+                else:
+                    raise ValueError(f"Model {arch} is not supported.")
             except Exception as e:
                 self.logger.error(
                     f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -164,7 +176,7 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
                 f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
             )

-        if isinstance(model_or_path, str):
+        if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
             from colossalai.inference.core.plugin import InferCheckpoint_io

             cpt_io = InferCheckpoint_io()
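Taken together, the two hunks mean that a checkpoint whose architecture maps to AutoModelForCausalLM (the Baichuan2-13B case) is materialized directly through transformers rather than built empty and filled by InferCheckpoint_io. A minimal sketch of that fallback path outside the engine (the model id below is illustrative, taken from the adapted-from comment later in this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "baichuan-inc/Baichuan2-13B-Base"  # illustrative; a local checkpoint directory works as well
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# fp16 is what the NOTE above is about: keeping the 13B weights in fp32 would overflow typical GPU memory.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda()
# Because the weights are already fully materialized here, the engine skips the
# InferCheckpoint_io loading step for this branch (second hunk above).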
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+from typing import List, Union
+
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.shardformer.layer import Linear1D_Col
+from colossalai.shardformer.layer.parallel_module import ParallelModule
+
+
+class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        module.in_features = module.weight.size(1)
+        module.out_features = module.weight.size(0)
+        module.bias = None
+        module.weight.data = nn.functional.normalize(module.weight)
+
+        return Linear1D_Col.from_native_module(
+            module,
+            process_group,
+            *args,
+            **kwargs,
+        )
+
+
+class BaichuanWpackLinear1D_Col(Linear1D_Col):
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        in_features = module.in_features * 3
+        out_features = module.out_features // 3
+        module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
+        module.bias = None
+
+        return Linear1D_Col.from_native_module(
+            module,
+            process_group,
+            *args,
+            **kwargs,
+        )
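The reshape in BaichuanWpackLinear1D_Col is the key trick: Baichuan stores q, k and v block-wise in W_pack, so a naive column-parallel split would hand one rank all of q and another all of v. Re-interleaving the rows first lets a plain split along dim 0 give every rank matching q/k/v slices. A self-contained toy check (not part of the commit; sizes and names are illustrative):

import torch

hidden, tp = 8, 2
q = torch.arange(hidden * hidden, dtype=torch.float32).reshape(hidden, hidden)
k, v = q + 1000, q + 2000
w_pack = torch.cat([q, k, v], dim=0)  # (3*hidden, hidden): q rows, then k rows, then v rows

# Same transform as BaichuanWpackLinear1D_Col.from_native_module:
out_features = w_pack.size(0) // 3
in_features = w_pack.size(1) * 3
interleaved = w_pack.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)

# Row r of `interleaved` is [q_r | k_r | v_r], so a contiguous split along dim 0
# (what the column-parallel shard performs) keeps q, k and v aligned per rank.
rank0 = interleaved.chunk(tp, dim=0)[0]  # rank 0's shard, shape (hidden // tp, 3*hidden)
q0, k0, v0 = rank0.view(out_features // tp, 3, -1).unbind(dim=1)
assert torch.equal(q0, q[: hidden // tp]) and torch.equal(v0, v[: hidden // tp])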

colossalai/inference/modeling/models/nopadding_baichuan.py

Lines changed: 117 additions & 55 deletions
@@ -1,11 +1,14 @@
 # This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
+import itertools
 import math
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
+from torch.distributed import ProcessGroup

 from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
     context_attention_unpadded,
@@ -16,6 +19,18 @@
     rotary_embedding,
 )
 from colossalai.logging import get_dist_logger
+from colossalai.shardformer.layer.parallel_module import ParallelModule
+from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
+
+logger = get_dist_logger(__name__)
+
+try:
+    from flash_attn import flash_attn_varlen_func
+
+    use_flash_attn2 = True
+except ImportError:
+    use_flash_attn2 = False
+    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")

 logger = get_dist_logger(__name__)

@@ -78,14 +93,18 @@ def baichuan_rmsnorm_forward(
     return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)


-class NopadBaichuanAttention(nn.Module):
+class NopadBaichuanAttention(ParallelModule):
     def __init__(
         self,
         config,
         attn_qproj_w: torch.Tensor = None,
         attn_kproj_w: torch.Tensor = None,
         attn_vproj_w: torch.Tensor = None,
-        attn_oproj_w: torch.Tensor = None,
+        attn_oproj: ParallelModule = None,
+        num_heads: int = None,
+        hidden_size: int = None,
+        process_group: ProcessGroup = None,
+        helper_layout: Layout = None,
     ):
         """This layer will replace the BaichuanAttention.
@@ -94,51 +113,112 @@ def __init__(
             attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
             attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
             attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
-            attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
+            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
         """
-        super().__init__()
-        self.o_proj_weight = attn_oproj_w
+        ParallelModule.__init__(self)
+        self.o_proj = attn_oproj

         self.config = config
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
         self.head_dim = self.hidden_size // self.num_heads
+        self.process_group = process_group
+        qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
+        self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
+
+        self.helper_layout = helper_layout
+
         self.alibi_slopes = None
         self.use_alibi_attn = False
-        if self.hidden_size == 5120:
+        # Used for Baichuan13B
+        if config.hidden_size == 5120:
+            slopes_start = self.process_group.rank() * num_heads
             self.use_alibi_attn = True
-            self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
-
-        qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
-        self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
+            self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
+                slopes_start : slopes_start + num_heads
+            ].contiguous()

     @staticmethod
-    def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBaichuanAttention":
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> "NopadBaichuanAttention":
         """Used for initialize the weight of NopadBaichuanAttention by origin BaichuanAttention.

         Args:
             module (nn.Module): The origin BaichuanAttention layer.
         """

         config = module.config
+        q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)

-        q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))
+        attn_qproj_w = q_proj_w
+        attn_kproj_w = k_proj_w
+        attn_vproj_w = v_proj_w
+        attn_oproj = module.o_proj

-        attn_qproj_w = q_proj_w.transpose(0, 1)
-        attn_kproj_w = k_proj_w.transpose(0, 1)
-        attn_vproj_w = v_proj_w.transpose(0, 1)
-        attn_oproj_w = module.o_proj.weight.transpose(0, 1)
+        helper_layout = (
+            module.W_pack.weight.dist_layout
+        )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)

         attn_layer = NopadBaichuanAttention(
             config=config,
             attn_qproj_w=attn_qproj_w,
             attn_kproj_w=attn_kproj_w,
             attn_vproj_w=attn_vproj_w,
-            attn_oproj_w=attn_oproj_w,
+            attn_oproj=attn_oproj,
+            num_heads=module.num_heads,
+            hidden_size=module.hidden_size,
+            process_group=process_group,
+            helper_layout=helper_layout,
         )

         return attn_layer

+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        key = "qkv_weight"
+        qkv_w = state_dict[prefix + "W_pack.weight"]
+
+        in_features = qkv_w.size(1)
+        out_features = qkv_w.size(0) // 3
+
+        qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)
+
+        device_mesh = self.helper_layout.device_mesh
+        sharding_spec = self.helper_layout.sharding_spec
+        qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)
+
+        qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
+        input_param = nn.Parameter(
+            qkv_w
+        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+
+        param = local_state[key]
+
+        try:
+            with torch.no_grad():
+                param.copy_(input_param)
+        except Exception as ex:
+            error_msgs.append(
+                'While copying the parameter named "{}", '
+                "whose dimensions in the model are {} and "
+                "whose dimensions in the checkpoint are {}, "
+                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+            )
+
+        strict = False  # to avoid unexpected_keys
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
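The ALiBi change in this hunk follows the same partitioning pattern as the qkv weight: the full slope vector is still computed from config.num_attention_heads, and each tensor-parallel rank keeps only the window that matches its local heads. A toy illustration (random placeholder slopes; the real values come from get_alibi_slopes):

import torch

total_heads, tp_size = 40, 2      # Baichuan2-13B: hidden 5120 / head_dim 128 = 40 heads
local_heads = total_heads // tp_size
slopes = torch.rand(total_heads)  # placeholder for get_alibi_slopes(config.num_attention_heads, ...)

for rank in range(tp_size):
    start = rank * local_heads    # slopes_start in the hunk above
    local_slopes = slopes[start : start + local_heads].contiguous()
    assert local_slopes.numel() == local_heads
# Each rank's alibi_slopes line up with the attention heads it owns after the qkv weight is column-sharded.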
@@ -292,56 +372,38 @@ def forward(
         )

         attn_output = attn_output.view(-1, self.hidden_size)
-        attn_output = torch.mm(attn_output, self.o_proj_weight)
+        attn_output = self.o_proj(attn_output)

         return attn_output

+    def extra_repr(self) -> str:
+        return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"

-# NOTE This will cause difference as out length increases.
-class NopadBaichuanMLP(nn.Module):
-    def __init__(
-        self,
-        mlp_gproj_w: torch.Tensor = None,
-        mlp_uproj_w: torch.Tensor = None,
-        mlp_dproj_w: torch.Tensor = None,
-    ):
-        """This layer will replace the BaichuanAttention.
-
-        Args:
-            mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
-            mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
-            mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
-        """
-        super().__init__()
-        self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
-        self.down_proj_weight = mlp_dproj_w

+# NOTE This will cause difference as out length increases.
+class NopadBaichuanMLP(NopadLlamaMLP):
     @staticmethod
-    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
         """Used for initialize the weight of NopadBaichuanMLP by origin MLP(Baichuan).

         Args:
             module (nn.Module): The origin MLP(Baichuan) layer.
         """
-
-        mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
-        mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
-        mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
+        mlp_gproj_w = module.gate_proj.weight
+        assert is_distributed_tensor(
+            module.gate_proj.weight
+        ), "gate_proj.weight must be dtensor so we could get the layout of the weight"
+        mlp_uproj_w = module.up_proj.weight
+        mlp_dproj = module.down_proj

         mlp_layer = NopadBaichuanMLP(
+            config=None,
             mlp_gproj_w=mlp_gproj_w,
             mlp_uproj_w=mlp_uproj_w,
-            mlp_dproj_w=mlp_dproj_w,
+            mlp_dproj=mlp_dproj,
+            process_group=process_group,
         )

         return mlp_layer
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
-        """
-        hidden_states = hidden_states.expand(2, -1, -1)
-        gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
-        act_out = inference_ops.silu_and_mul(gate_up_proj_out)
-        return torch.mm(act_out, self.down_proj_weight)