@@ -1,31 +1,39 @@
 import os
 import torch
-from .base_weight import BaseWeight
-from lightllm.utils.dist_utils import get_world_size, get_rank
 import threading
+from typing import Optional, Tuple, List, Dict, Any
+from .base_weight import BaseWeight
 from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
-
+from lightllm.common.quantization.quantize_method import QuantizationMethod
+from lightllm.utils.dist_utils import get_world_size, get_rank
 from lightllm.common.vllm_kernel import _custom_ops as ops
 from lightllm.utils.device_utils import get_current_device_id
 
 
 class FusedMoeWeight(BaseWeight):
     def __init__(
         self,
-        gate_proj_name,
-        down_proj_name,
-        up_proj_name,
-        e_score_correction_bias_name,
-        weight_prefix,
-        n_routed_experts,
-        split_inter_size,
-        data_type,
-        network_config,
-    ):
+        gate_proj_name: str,
+        down_proj_name: str,
+        up_proj_name: str,
+        e_score_correction_bias_name: str,
+        weight_prefix: str,
+        n_routed_experts: int,
+        split_inter_size: int,
+        data_type: torch.dtype,
+        network_config: Dict[str, Any],
+        weight_scale_suffix: Optional[str] = None,
+        act_scale_suffix: Optional[str] = None,
+    ) -> None:
         super().__init__()
         self.w1_weight_name = gate_proj_name
         self.w2_weight_name = down_proj_name
         self.w3_weight_name = up_proj_name
+        self.weight_scale_suffix = weight_scale_suffix
+        self.act_scale_suffix = act_scale_suffix
+        self.quantized_weight = weight_scale_suffix is not None
+        self.static_activation = act_scale_suffix is not None
+
         self.e_score_correction_bias_name = e_score_correction_bias_name
         self.weight_prefix = weight_prefix
         self.n_routed_experts = n_routed_experts
@@ -34,15 +42,23 @@ def __init__(
         self.tp_rank_ = get_rank()
         self.experts_up_projs = [None] * self.n_routed_experts
         self.experts_gate_projs = [None] * self.n_routed_experts
+        self.experts_up_proj_scales = [None] * self.n_routed_experts
+        self.experts_gate_proj_scales = [None] * self.n_routed_experts
         self.expert_gate_up_proj_etp = None
         self.expert_down_proj_etp = None
         self.e_score_correction_bias = None
         self.w2_list = [None] * self.n_routed_experts
+        self.w2_scale_list = [None] * self.n_routed_experts
         self.quant_method = None
         self.scoring_func = network_config["scoring_func"]
+        self.w1 = [None, None]  # weight, weight_scale
+        self.w2 = [None, None]  # weight, weight_scale
         self.lock = threading.Lock()
 
-    def set_quant_method(self, quant_method):
+    def set_quant_method(self, quant_method: QuantizationMethod) -> None:
+        if self.quantized_weight:
+            self.quant_method = quant_method
+            return
         if isinstance(quant_method, vLLMFP8w8a8QuantizationMethod):
             self.quant_method = quant_method
             if self.quant_method is not None:
@@ -82,6 +98,8 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
         return
 
     def _fuse(self):
+        if self.quantized_weight:
+            self._fuse_weight_scale()
         with self.lock:
             if (
                 hasattr(self, "experts_up_projs")
@@ -98,21 +116,48 @@ def _fuse(self):
                     w1_list.append(expert_gate_up_proj)
 
                 inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1]
-                self.w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size)
+                w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size)
                 inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
-                self.w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(
-                    len(self.w2_list), inter_shape, hidden_size
-                )
-                if self.quant_method is not None:
-                    self.w1 = self.quant_method.quantize(self.w1)
-                    self.w2 = self.quant_method.quantize(self.w2)
+                w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
+                if not self.quantized_weight and self.quant_method is not None:
+                    self.w1 = self.quant_method.quantize(w1)
+                    self.w2 = self.quant_method.quantize(w2)
                 else:
-                    self.w1 = [self._cuda(self.w1), None]
-                    self.w2 = [self._cuda(self.w2), None]
+                    self.w1[0] = self._cuda(w1)
+                    self.w2[0] = self._cuda(w2)
                 delattr(self, "w2_list")
                 delattr(self, "experts_up_projs")
                 delattr(self, "experts_gate_projs")
 
+    def _fuse_weight_scale(self):
+        with self.lock:
+            if (
+                hasattr(self, "experts_up_proj_scales")
+                and None not in self.experts_up_proj_scales
+                and None not in self.experts_gate_proj_scales
+                and None not in self.w2_scale_list
+            ):
+                w1_scale_list = []
+                for i_experts in range(self.n_routed_experts):
+                    expert_gate_up_proj_scale = torch.cat(
+                        [self.experts_gate_proj_scales[i_experts], self.experts_up_proj_scales[i_experts]], dim=0
+                    )
+                    w1_scale_list.append(expert_gate_up_proj_scale)
+
+                inter_shape, hidden_size = w1_scale_list[0].shape[0], w1_scale_list[0].shape[1]
+                w1_scale = torch._utils._flatten_dense_tensors(w1_scale_list).view(
+                    len(w1_scale_list), inter_shape, hidden_size
+                )
+                inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1]
+                w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view(
+                    len(self.w2_scale_list), inter_shape, hidden_size
+                )
+                self.w1[1] = self._cuda(w1_scale)
+                self.w2[1] = self._cuda(w2_scale)
+                delattr(self, "w2_scale_list")
+                delattr(self, "experts_up_proj_scales")
+                delattr(self, "experts_gate_proj_scales")
+
     def _load_hf_weights_etp(self, weights):
         world_size_ = get_world_size()
         assert self.n_routed_experts % world_size_ == 0
@@ -196,11 +241,51 @@ def load_hf_weights(self, weights):
                    self.w2_list[i_experts] = weights[w2_weight][
                        :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
                    ]
-
+            if self.quant_method is not None:
+                self._load_weight_scale(weights)
             self._fuse()
 
+    def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None:
+        block_size = 1
+        if hasattr(self.quant_method, "block_size"):
+            block_size = self.quant_method.block_size
+        for i_experts in range(self.n_routed_experts):
+            w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}"
+            w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}"
+            w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}"
+            if w1_scale in weights:
+                self.experts_gate_proj_scales[i_experts] = weights[w1_scale][
+                    self.split_inter_size
+                    // block_size
+                    * self.tp_rank_ : self.split_inter_size
+                    // block_size
+                    * (self.tp_rank_ + 1),
+                    :,
+                ]
+            if w3_scale in weights:
+                self.experts_up_proj_scales[i_experts] = weights[w3_scale][
+                    self.split_inter_size
+                    // block_size
+                    * self.tp_rank_ : self.split_inter_size
+                    // block_size
+                    * (self.tp_rank_ + 1),
+                    :,
+                ]
+
+            if w2_scale in weights:
+                self.w2_scale_list[i_experts] = weights[w2_scale][
+                    :,
+                    self.split_inter_size
+                    // block_size
+                    * self.tp_rank_ : self.split_inter_size
+                    // block_size
+                    * (self.tp_rank_ + 1),
+                ]
+
     def _cuda(self, cpu_tensor):
         device_id = get_current_device_id()
+        if self.quantized_weight:
+            return cpu_tensor.contiguous().cuda(device_id)
         return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id)
 
     def verify_load(self):
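
Reviewer note: below is a standalone sketch of the tensor-parallel slicing that the new _load_weight_scale applies to block-wise weight scales. Each rank takes split_inter_size // block_size rows of the scale tensor, mirroring the split_inter_size-row slice taken from the weight itself in load_hf_weights. All shapes, the block size, and the rank count are illustrative assumptions, not values taken from this PR.

import torch

# Illustrative values (assumptions, not from this PR): one expert's gate_proj
# weight of shape [inter_size, hidden_size], quantized block-wise with one
# scale per block_size x block_size tile.
inter_size, hidden_size, block_size = 1024, 512, 128
world_size = 2                               # tensor-parallel ranks
split_inter_size = inter_size // world_size  # weight rows owned by each rank

weight = torch.randn(inter_size, hidden_size)
weight_scale = torch.rand(inter_size // block_size, hidden_size // block_size)

for tp_rank in range(world_size):
    # Weight slice, as in load_hf_weights.
    w = weight[split_inter_size * tp_rank : split_inter_size * (tp_rank + 1), :]
    # Matching scale slice, as in _load_weight_scale: row indices are divided
    # by block_size so the scale rows line up with the weight rows above.
    s = weight_scale[
        split_inter_size // block_size * tp_rank : split_inter_size // block_size * (tp_rank + 1),
        :,
    ]
    assert w.shape[0] == s.shape[0] * block_size
    print(tp_rank, w.shape, s.shape)  # 0 torch.Size([512, 512]) torch.Size([4, 4])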