# This code is adapted from the huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
+ import itertools
import math
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
+ from torch.distributed import ProcessGroup

from colossalai.inference.flash_decoding_utils import FDIntermTensors
+ from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    context_attention_unpadded,
    rotary_embedding,
)
from colossalai.logging import get_dist_logger
+ from colossalai.shardformer.layer.parallel_module import ParallelModule
+ from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
+
+ logger = get_dist_logger(__name__)
+
+ try:
+     from flash_attn import flash_attn_varlen_func
+
+     use_flash_attn2 = True
+ except ImportError:
+     use_flash_attn2 = False
+     logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")

logger = get_dist_logger(__name__)

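The try/except above probes for flash-attn once at import time and records the result in a module-level flag instead of failing hard. A minimal, self-contained sketch of that pattern, assuming a hypothetical naive_attention fallback in place of the Triton kernels this file actually dispatches to:

import torch

try:
    from flash_attn import flash_attn_varlen_func

    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False


def naive_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # (tokens, heads, head_dim) -> (tokens, heads, head_dim); no masking, illustration only.
    scores = torch.einsum("qhd,khd->hqk", q, k) / q.shape[-1] ** 0.5
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)


def attend(q, k, v, cu_seqlens, max_seqlen):
    # Branch on the import-time flag; flash_attn_varlen_func consumes the packed
    # no-padding layout plus cumulative sequence lengths.
    if use_flash_attn2:
        return flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
    return naive_attention(q, k, v)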
@@ -78,14 +93,18 @@ def baichuan_rmsnorm_forward(
    return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)


- class NopadBaichuanAttention(nn.Module):
+ class NopadBaichuanAttention(ParallelModule):
    def __init__(
        self,
        config,
        attn_qproj_w: torch.Tensor = None,
        attn_kproj_w: torch.Tensor = None,
        attn_vproj_w: torch.Tensor = None,
-         attn_oproj_w: torch.Tensor = None,
+         attn_oproj: ParallelModule = None,
+         num_heads: int = None,
+         hidden_size: int = None,
+         process_group: ProcessGroup = None,
+         helper_layout: Layout = None,
    ):
        """This layer will replace the BaichuanAttention.

@@ -94,51 +113,112 @@ def __init__(
            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
-             attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
+             attn_oproj (Linear1D_Row, optional): The o_proj module as a Linear1D_Row. Defaults to None.
        """
-         super().__init__()
-         self.o_proj_weight = attn_oproj_w
+         ParallelModule.__init__(self)
+         self.o_proj = attn_oproj

        self.config = config
-         self.hidden_size = config.hidden_size
-         self.num_heads = config.num_attention_heads
+         self.num_heads = num_heads
+         self.hidden_size = hidden_size
        self.head_dim = self.hidden_size // self.num_heads
+         self.process_group = process_group
+         qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
+         self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
+
+         self.helper_layout = helper_layout
+
        self.alibi_slopes = None
        self.use_alibi_attn = False
-         if self.hidden_size == 5120:
+         # Used for Baichuan-13B
+         if config.hidden_size == 5120:
+             slopes_start = self.process_group.rank() * num_heads
            self.use_alibi_attn = True
-             self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
-
-         qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
-         self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
+             self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
+                 slopes_start : slopes_start + num_heads
+             ].contiguous()

    @staticmethod
-     def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBaichuanAttention":
+     def from_native_module(
+         module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+     ) -> "NopadBaichuanAttention":
        """Initialize the weight of NopadBaichuanAttention from the original BaichuanAttention.

        Args:
            module (nn.Module): The original BaichuanAttention layer.
        """
        config = module.config
+         q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)

-         q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))
+         attn_qproj_w = q_proj_w
+         attn_kproj_w = k_proj_w
+         attn_vproj_w = v_proj_w
+         attn_oproj = module.o_proj

-         attn_qproj_w = q_proj_w.transpose(0, 1)
-         attn_kproj_w = k_proj_w.transpose(0, 1)
-         attn_vproj_w = v_proj_w.transpose(0, 1)
-         attn_oproj_w = module.o_proj.weight.transpose(0, 1)
+         helper_layout = (
+             module.W_pack.weight.dist_layout
+         )  # NOTE this is a hack for the right load/shard of qkv_weight (used in _load_from_state_dict)

        attn_layer = NopadBaichuanAttention(
            config=config,
            attn_qproj_w=attn_qproj_w,
            attn_kproj_w=attn_kproj_w,
            attn_vproj_w=attn_vproj_w,
-             attn_oproj_w=attn_oproj_w,
+             attn_oproj=attn_oproj,
+             num_heads=module.num_heads,
+             hidden_size=module.hidden_size,
+             process_group=process_group,
+             helper_layout=helper_layout,
        )

        return attn_layer

+     def _load_from_state_dict(
+         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+     ):
+         for hook in self._load_state_dict_pre_hooks.values():
+             hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+         persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+         local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+         local_state = {k: v for k, v in local_name_params if v is not None}
+
+         key = "qkv_weight"
+         qkv_w = state_dict[prefix + "W_pack.weight"]
+
+         in_features = qkv_w.size(1)
+         out_features = qkv_w.size(0) // 3
+
+         qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)
+
+         device_mesh = self.helper_layout.device_mesh
+         sharding_spec = self.helper_layout.sharding_spec
+         qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)
+
+         qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
+         input_param = nn.Parameter(
+             qkv_w
+         )  # NOTE qkv_weight doesn't have to be a dtensor, e.g. input_param = sharded_tensor_to_param(input_param)
+
+         param = local_state[key]
+
+         try:
+             with torch.no_grad():
+                 param.copy_(input_param)
+         except Exception as ex:
+             error_msgs.append(
+                 'While copying the parameter named "{}", '
+                 "whose dimensions in the model are {} and "
+                 "whose dimensions in the checkpoint are {}, "
+                 "an exception occurred: {}.".format(key, param.size(), input_param.size(), ex.args)
+             )
+
+         strict = False  # to avoid unexpected_keys
+         super()._load_from_state_dict(
+             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+         )
+
    def forward(
        self,
        hidden_states: torch.Tensor,
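Two pieces of this hunk are easy to sanity-check in isolation. First, __init__ now stacks the transposed q/k/v weights into a single (3, in, out) parameter, which lets one batched matmul stand in for three separate projections. How the forward in this file actually consumes qkv_weight is outside this hunk, so the bmm below is only a sketch of the layout, with toy sizes and illustrative names:

import torch

H, T = 16, 5  # toy hidden size and token count
q_w, k_w, v_w = (torch.randn(H, H) for _ in range(3))  # nn.Linear weights, shape (out, in)
x = torch.randn(T, H)  # no-padding layout: (tokens, hidden)

# Stack the transposed weights exactly as __init__ does: qkv_weight has shape (3, in, out).
qkv_weight = torch.stack([q_w.t(), k_w.t(), v_w.t()], dim=0)

# One batched matmul yields all three projections at once.
q, k, v = torch.bmm(x.unsqueeze(0).expand(3, -1, -1), qkv_weight).unbind(0)

assert torch.allclose(q, x @ q_w.t(), atol=1e-4)
assert torch.allclose(k, x @ k_w.t(), atol=1e-4)
assert torch.allclose(v, x @ v_w.t(), atol=1e-4)

Second, for the Baichuan-13B case (hidden_size == 5120, ALiBi attention), the slopes are built for all config.num_attention_heads heads and each tensor-parallel rank keeps only its contiguous chunk. A standalone sketch of that slicing; alibi_slopes below is a commonly used slope formula assumed to be equivalent to this file's get_alibi_slopes helper, which is not shown in this hunk:

import math

import torch


def alibi_slopes(total_heads: int) -> torch.Tensor:
    # Geometric slopes for the nearest power of two, plus interleaved extras for the rest.
    closest_pow2 = 2 ** math.floor(math.log2(total_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_pow2) - 3)))
    slopes = torch.tensor([base**i for i in range(1, closest_pow2 + 1)])
    if closest_pow2 != total_heads:
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_pow2) - 3)))
        num_extra = total_heads - closest_pow2
        extra = torch.tensor([extra_base**i for i in range(1, 2 * num_extra + 1, 2)])
        slopes = torch.cat([slopes, extra])
    return slopes


# Per-rank slicing as done in __init__: rank r keeps heads [r * heads_per_rank, (r + 1) * heads_per_rank).
total_heads, tp_size = 40, 2  # Baichuan-13B has 40 heads; tp_size = 2 is illustrative
heads_per_rank = total_heads // tp_size
for rank in range(tp_size):
    slopes_start = rank * heads_per_rank
    local_slopes = alibi_slopes(total_heads)[slopes_start : slopes_start + heads_per_rank].contiguous()
    assert local_slopes.shape == (heads_per_rank,)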
@@ -292,56 +372,38 @@ def forward(
        )

        attn_output = attn_output.view(-1, self.hidden_size)
-         attn_output = torch.mm(attn_output, self.o_proj_weight)
+         attn_output = self.o_proj(attn_output)

        return attn_output

+     def extra_repr(self) -> str:
+         return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"

- # NOTE This will cause difference as out length increases.
- class NopadBaichuanMLP(nn.Module):
-     def __init__(
-         self,
-         mlp_gproj_w: torch.Tensor = None,
-         mlp_uproj_w: torch.Tensor = None,
-         mlp_dproj_w: torch.Tensor = None,
-     ):
-         """This layer will replace the BaichuanAttention.
-
-         Args:
-             mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
-             mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
-             mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
-         """
-         super().__init__()
-         self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
-         self.down_proj_weight = mlp_dproj_w

+ # NOTE This will cause difference as out length increases.
+ class NopadBaichuanMLP(NopadLlamaMLP):
    @staticmethod
-     def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+     def from_native_module(
+         module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+     ) -> ParallelModule:
        """Initialize the weight of NopadBaichuanMLP from the original MLP (Baichuan).

        Args:
            module (nn.Module): The original MLP (Baichuan) layer.
        """
-
-         mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
-         mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
-         mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
+         mlp_gproj_w = module.gate_proj.weight
+         assert is_distributed_tensor(
+             module.gate_proj.weight
+         ), "gate_proj.weight must be a dtensor so we can get the layout of the weight"
+         mlp_uproj_w = module.up_proj.weight
+         mlp_dproj = module.down_proj

        mlp_layer = NopadBaichuanMLP(
+             config=None,
            mlp_gproj_w=mlp_gproj_w,
            mlp_uproj_w=mlp_uproj_w,
-             mlp_dproj_w=mlp_dproj_w,
+             mlp_dproj=mlp_dproj,
+             process_group=process_group,
        )

        return mlp_layer
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
-         """
-         hidden_states = hidden_states.expand(2, -1, -1)
-         gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
-         act_out = inference_ops.silu_and_mul(gate_up_proj_out)
-         return torch.mm(act_out, self.down_proj_weight)
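For reference, the removed forward above is the fused-MLP pattern that NopadLlamaMLP now supplies instead: the transposed gate and up projections are stacked into one (2, hidden, intermediate) weight, applied with a single bmm, and combined by a SiLU-and-multiply before the down projection. A self-contained sketch with toy sizes, assuming plain torch.nn.functional.silu in place of the inference_ops.silu_and_mul kernel:

import torch
import torch.nn.functional as F

H, inter, T = 16, 32, 5  # toy hidden size, intermediate size, token count
gate_w, up_w = torch.randn(inter, H), torch.randn(inter, H)  # nn.Linear weights, shape (out, in)
down_w = torch.randn(H, inter)
x = torch.randn(T, H)  # no-padding layout: (tokens, hidden)

# Stack the transposed gate/up weights into (2, hidden, intermediate), as the removed __init__ did.
gate_up_weight = torch.stack([gate_w.t(), up_w.t()], dim=0)

# One bmm produces both projections; F.silu stands in for inference_ops.silu_and_mul.
gate_up = torch.bmm(x.expand(2, -1, -1), gate_up_weight)  # (2, tokens, intermediate)
act = F.silu(gate_up[0]) * gate_up[1]                     # (tokens, intermediate)
out = act @ down_w.t()                                    # (tokens, hidden)

# Matches the unfused gate/up/down computation.
ref = (F.silu(x @ gate_w.t()) * (x @ up_w.t())) @ down_w.t()
assert torch.allclose(out, ref, atol=1e-4)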