 import math
 import numpy as np
 from lightllm.common.basemodel import TransformerLayerWeight
-from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight, CustomMMWeight, FusedMoeWeight, CustomBMMWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+    ROWMMWeight,
+    COLMMWeight,
+    NormWeight,
+    CustomMMWeight,
+    FusedMoeWeight,
+    CustomBMMWeight,
+)
 from functools import partial
 
 
@@ -19,16 +26,32 @@ def fuse_q_kb(self, A, B):
     k_nope_proj_ = k_b_proj_.unsqueeze(0)
     k_nope_proj_ = k_nope_proj_.to(torch.float64)
 
-    return self._cuda(torch.matmul(q_nope_proj_, k_nope_proj_).view(-1, self.tp_q_head_num_ * self.kv_lora_rank).transpose(0, 1))
+    return self._cuda(
+        torch.matmul(q_nope_proj_, k_nope_proj_).view(-1, self.tp_q_head_num_ * self.kv_lora_rank).transpose(0, 1)
+    )
+
 
 def fuse_vb_o(self, A, B):
     v_b_proj_ = A.weight
-    o_weight_ = B.weight.transpose(0, 1).view(self.tp_q_head_num_, self.qk_nope_head_dim, -1).contiguous().to(self.data_type_).cpu()
-    return self._cuda(torch.matmul(v_b_proj_.to(torch.float64), o_weight_.to(torch.float64)).view(-1, self.network_config_["hidden_size"]))
+    o_weight_ = (
+        B.weight.transpose(0, 1)
+        .view(self.tp_q_head_num_, self.qk_nope_head_dim, -1)
+        .contiguous()
+        .to(self.data_type_)
+        .cpu()
+    )
+    return self._cuda(
+        torch.matmul(v_b_proj_.to(torch.float64), o_weight_.to(torch.float64)).view(
+            -1, self.network_config_["hidden_size"]
+        )
+    )
+
 
 def load_q_rope(self, A, q_weight_):
     q_split_n_embed_with_rope = A.split_n_embed
-    q_weight_ = q_weight_[q_split_n_embed_with_rope * self.tp_rank_ : q_split_n_embed_with_rope * (self.tp_rank_ + 1), :]
+    q_weight_ = q_weight_[
+        q_split_n_embed_with_rope * self.tp_rank_ : q_split_n_embed_with_rope * (self.tp_rank_ + 1), :
+    ]
     q_weight_ = q_weight_.transpose(0, 1).contiguous()
     q_nope_proj_, q_rope_proj_ = torch.split(
         q_weight_.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
@@ -37,6 +60,7 @@ def load_q_rope(self, A, q_weight_):
     )
     return self._cuda(q_rope_proj_.reshape(-1, self.qk_rope_head_dim * self.tp_q_head_num_).transpose(0, 1))
 
+
 def load_kb(self, A, kv_b_proj_):
     kv_b_proj_ = kv_b_proj_
     k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)[
@@ -47,22 +71,31 @@ def load_kb(self, A, kv_b_proj_):
         return k_b_proj_.contiguous().to(self.data_type_).cpu()
     return self._cuda(k_b_proj_)
 
+
 def load_vb(self, A, kv_b_proj_):
     kv_b_proj_ = kv_b_proj_
-    v_b_proj_ = kv_b_proj_.T.view(
-        self.kv_lora_rank,
-        self.num_attention_heads,
-        self.qk_nope_head_dim * 2,
-    )[:, :, self.qk_nope_head_dim :].transpose(0, 1)
-    v_b_proj_ = v_b_proj_[
-        self.tp_q_head_num_ * self.tp_rank_ : self.tp_q_head_num_ * (self.tp_rank_ + 1), :, :
-    ]
+    v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, self.qk_nope_head_dim * 2,)[
+        :, :, self.qk_nope_head_dim :
+    ].transpose(0, 1)
+    v_b_proj_ = v_b_proj_[self.tp_q_head_num_ * self.tp_rank_ : self.tp_q_head_num_ * (self.tp_rank_ + 1), :, :]
     if A.wait_fuse:
         return v_b_proj_.contiguous().to(self.data_type_).cpu()
     return self._cuda(v_b_proj_)
 
+
 class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
-    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None, disable_qk_absorb=False, disable_vo_absorb=False):
+    def __init__(
+        self,
+        layer_num,
+        tp_rank,
+        world_size,
+        data_type,
+        network_config,
+        mode=[],
+        quant_cfg=None,
+        disable_qk_absorb=False,
+        disable_vo_absorb=False,
+    ):
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         self.is_moe = (
             self.network_config_["n_routed_experts"] is not None
@@ -86,9 +119,11 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
         self.fuse_pairs = {"q_b_proj_&k_b_proj_": "fuse_qk_weight_"}
         if not self.disable_vo_absorb:
             self.fuse_pairs["v_b_proj_&o_weight_"] = "fuse_vo_weight_"
-        self.fuse_pairs.update({
-            "gate_proj&up_proj": "gate_up_proj",
-        })
+        self.fuse_pairs.update(
+            {
+                "gate_proj&up_proj": "gate_up_proj",
+            }
+        )
 
         self.init_qkvo()
         if self.is_moe:
@@ -115,7 +150,10 @@ def init_qkvo(self):
             rope_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"
         else:
             self.q_a_proj_ = ROWMMWeight(
-                f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight", self.data_type_, self.q_lora_rank, disable_tp=True
+                f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight",
+                self.data_type_,
+                self.q_lora_rank,
+                disable_tp=True,
             )
             self.q_b_proj_ = CustomMMWeight(
                 f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight",
@@ -126,10 +164,7 @@ def init_qkvo(self):
             )
             rope_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight"
         self.q_rope_proj_ = CustomMMWeight(
-            rope_weight_name,
-            self.data_type_,
-            q_split_n_embed_with_rope,
-            custom_load=partial(load_q_rope, self)
+            rope_weight_name, self.data_type_, q_split_n_embed_with_rope, custom_load=partial(load_q_rope, self)
         )
         self.kv_a_proj_with_mqa_ = ROWMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight",
@@ -142,50 +177,47 @@ def init_qkvo(self):
             self.data_type_,
             None,
             wait_fuse=not self.disable_qk_absorb,
-            custom_load=partial(load_kb, self)
+            custom_load=partial(load_kb, self),
         )
         self.v_b_proj_ = CustomBMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight",
             self.data_type_,
             None,
             wait_fuse=not self.disable_vo_absorb,
             custom_load=partial(load_vb, self),
-            custom_fuse=partial(fuse_vb_o, self)
+            custom_fuse=partial(fuse_vb_o, self),
         )
         self.o_weight_ = COLMMWeight(
-            f"model.layers.{self.layer_num_}.self_attn.o_proj.weight", self.data_type_, q_split_n_embed, wait_fuse=not self.disable_vo_absorb,
+            f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
+            self.data_type_,
+            q_split_n_embed,
+            wait_fuse=not self.disable_vo_absorb,
         )
 
     def _load_mlp(self, mlp_prefix, split_inter_size):
         self.gate_proj = ROWMMWeight(
             f"{mlp_prefix}.gate_proj.weight", self.data_type_, split_inter_size, wait_fuse=True
         )
-        self.up_proj = ROWMMWeight(
-            f"{mlp_prefix}.up_proj.weight", self.data_type_, split_inter_size, wait_fuse=True
-        )
-        self.down_proj = COLMMWeight(
-            f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size
-        )
+        self.up_proj = ROWMMWeight(f"{mlp_prefix}.up_proj.weight", self.data_type_, split_inter_size, wait_fuse=True)
+        self.down_proj = COLMMWeight(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size)
 
     def init_moe(self):
         moe_intermediate_size = self.network_config_["moe_intermediate_size"]
         self.moe_gate = ROWMMWeight(
             f"model.layers.{self.layer_num_}.mlp.gate.weight", self.data_type_, moe_intermediate_size, disable_tp=True
         )
-        shared_intermediate_size = (
-            moe_intermediate_size * self.network_config_["n_shared_experts"]
-        )
+        shared_intermediate_size = moe_intermediate_size * self.network_config_["n_shared_experts"]
         shared_split_inter_size = shared_intermediate_size // self.world_size_
         self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", shared_split_inter_size)
-
+
         self.experts = FusedMoeWeight(
             gate_proj_name="gate_proj",
             down_proj_name="down_proj",
             up_proj_name="up_proj",
             weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts",
             n_routed_experts=self.n_routed_experts,
             split_inter_size=moe_intermediate_size // self.world_size_,
-            data_type=self.data_type_
+            data_type=self.data_type_,
         )
 
     def init_ffn(self):
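
For context (not part of the diff): fuse_q_kb and fuse_vb_o perform the usual MLA weight absorption, folding k_b_proj into the per-head q_nope projection and v_b_proj into o_proj, so queries and outputs are mapped directly through the kv_lora_rank latent space. A minimal, standalone sketch of the shape logic behind fuse_q_kb, using made-up sizes in place of the config values:

import torch

# Illustrative sizes only; the real values come from the model config
# (tp_q_head_num_, qk_nope_head_dim, kv_lora_rank).
hidden, heads, d_nope, d_lora = 32, 4, 8, 16

# Per-head q_nope projection (reshaped for broadcasting) and the k_b projection.
w_q_nope = torch.randn(hidden, heads, 1, d_nope, dtype=torch.float64)
w_k_b = torch.randn(1, heads, d_nope, d_lora, dtype=torch.float64)

# Multiplying the two weights offline "absorbs" k_b into q: the fused matrix
# maps hidden states straight into the kv_lora_rank latent space, so attention
# scores can be computed against the compressed KV cache without expanding it
# per head at runtime. The view/transpose matches the layout returned above.
fused = torch.matmul(w_q_nope, w_k_b).view(-1, heads * d_lora).transpose(0, 1)
assert fused.shape == (heads * d_lora, hidden)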