
Commit 6020ab2

refactor quantization to load quantized weights and add group block gemm
1 parent d7fdff4 commit 6020ab2

14 files changed: +749 −308 lines

lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 1 deletion
@@ -99,6 +99,7 @@ def _init_config(self):
         repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
         if self.finetune_config:
             self.config["vocab_size"] = self.finetune_config.vocab_size
+
         return

     @final
@@ -112,7 +113,7 @@ def _verify_params(self):
         return

     def _init_quant(self):
-        self.quant_cfg = Quantcfg(self.config["n_layer"], self.quant_type, self.quant_cfg_path)
+        self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path)
         logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}")

     def _init_weights(self):
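
The only functional change in this file is that Quantcfg now receives the whole config dict instead of config["n_layer"]. A minimal sketch, assuming a config-driven constructor in that spirit (QuantCfgSketch, the quantization_config key, and the attribute names are illustrative, not the real Quantcfg):

```python
from typing import Any, Dict, Optional


class QuantCfgSketch:
    """Hypothetical stand-in for Quantcfg, shown only to illustrate the new signature."""

    def __init__(self, config: Dict[str, Any], quant_type: Optional[str], cfg_path: Optional[str]) -> None:
        # With the full config available, more than the layer count can be read,
        # e.g. a quantization_config block shipped with a pre-quantized checkpoint.
        self.layer_num = config.get("n_layer", config.get("num_hidden_layers"))
        self.hf_quantization_config = config.get("quantization_config")
        self.quant_type = quant_type
        self.cfg_path = cfg_path
```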

lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -10,10 +10,8 @@
     MultiROWMMWeightNoTP,
     MultiCOLMMWeight,
     ROWBMMWeight,
-    COLBMMWeight,
     MultiCOLMMWeightNoTp,
     ROWBMMWeightNoTp,
-    COLBMMWeightNoTp,
     COLMMWeightNoTp,
 )
 from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py

Lines changed: 109 additions & 24 deletions
@@ -1,31 +1,39 @@
 import os
 import torch
-from .base_weight import BaseWeight
-from lightllm.utils.dist_utils import get_world_size, get_rank
 import threading
+from typing import Optional, Tuple, List, Dict, Any
+from .base_weight import BaseWeight
 from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
-
+from lightllm.common.quantization.quantize_method import QuantizationMethod
+from lightllm.utils.dist_utils import get_world_size, get_rank
 from lightllm.common.vllm_kernel import _custom_ops as ops
 from lightllm.utils.device_utils import get_current_device_id


 class FusedMoeWeight(BaseWeight):
     def __init__(
         self,
-        gate_proj_name,
-        down_proj_name,
-        up_proj_name,
-        e_score_correction_bias_name,
-        weight_prefix,
-        n_routed_experts,
-        split_inter_size,
-        data_type,
-        network_config,
-    ):
+        gate_proj_name: str,
+        down_proj_name: str,
+        up_proj_name: str,
+        e_score_correction_bias_name: str,
+        weight_prefix: str,
+        n_routed_experts: int,
+        split_inter_size: int,
+        data_type: torch.dtype,
+        network_config: Dict[str, Any],
+        weight_scale_suffix: Optional[str] = None,
+        act_scale_suffix: Optional[str] = None,
+    ) -> None:
         super().__init__()
         self.w1_weight_name = gate_proj_name
         self.w2_weight_name = down_proj_name
         self.w3_weight_name = up_proj_name
+        self.weight_scale_suffix = weight_scale_suffix
+        self.act_scale_suffix = act_scale_suffix
+        self.quantized_weight = weight_scale_suffix is not None
+        self.static_activation = act_scale_suffix is not None
+
         self.e_score_correction_bias_name = e_score_correction_bias_name
         self.weight_prefix = weight_prefix
         self.n_routed_experts = n_routed_experts
@@ -34,15 +42,23 @@ def __init__(
         self.tp_rank_ = get_rank()
         self.experts_up_projs = [None] * self.n_routed_experts
         self.experts_gate_projs = [None] * self.n_routed_experts
+        self.experts_up_proj_scales = [None] * self.n_routed_experts
+        self.experts_gate_proj_scales = [None] * self.n_routed_experts
         self.expert_gate_up_proj_etp = None
         self.expert_down_proj_etp = None
         self.e_score_correction_bias = None
         self.w2_list = [None] * self.n_routed_experts
+        self.w2_scale_list = [None] * self.n_routed_experts
         self.quant_method = None
         self.scoring_func = network_config["scoring_func"]
+        self.w1 = [None, None]  # weight, weight_scale
+        self.w2 = [None, None]  # weight, weight_scale
         self.lock = threading.Lock()

-    def set_quant_method(self, quant_method):
+    def set_quant_method(self, quant_method: QuantizationMethod) -> None:
+        if self.quantized_weight:
+            self.quant_method = quant_method
+            return
         if isinstance(quant_method, vLLMFP8w8a8QuantizationMethod):
             self.quant_method = quant_method
         if self.quant_method is not None:
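
The two new constructor arguments drive everything else in this file: a non-None weight_scale_suffix marks a checkpoint that already ships quantized weights plus scales, and set_quant_method then keeps whatever method is passed in instead of only accepting the vLLM FP8 w8a8 method. A standalone sketch of that flag derivation (the suffix string is an assumed example):

```python
from typing import Optional, Tuple


def derive_flags(weight_scale_suffix: Optional[str], act_scale_suffix: Optional[str]) -> Tuple[bool, bool]:
    # Mirrors the two assignments added to __init__.
    quantized_weight = weight_scale_suffix is not None   # checkpoint already carries quantized weights + scales
    static_activation = act_scale_suffix is not None     # checkpoint also carries static activation scales
    return quantized_weight, static_activation


print(derive_flags("weight_scale_inv", None))  # (True, False): set_quant_method keeps any passed method and returns early
print(derive_flags(None, None))                # (False, False): old behaviour, weights are quantized at load time in _fuse
```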
@@ -82,6 +98,8 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
         return

     def _fuse(self):
+        if self.quantized_weight:
+            self._fuse_weight_scale()
         with self.lock:
             if (
                 hasattr(self, "experts_up_projs")
@@ -98,21 +116,48 @@ def _fuse(self):
                     w1_list.append(expert_gate_up_proj)

                 inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1]
-                self.w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size)
+                w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size)
                 inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
-                self.w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(
-                    len(self.w2_list), inter_shape, hidden_size
-                )
-                if self.quant_method is not None:
-                    self.w1 = self.quant_method.quantize(self.w1)
-                    self.w2 = self.quant_method.quantize(self.w2)
+                w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
+                if not self.quantized_weight and self.quant_method is not None:
+                    self.w1 = self.quant_method.quantize(w1)
+                    self.w2 = self.quant_method.quantize(w2)
                 else:
-                    self.w1 = [self._cuda(self.w1), None]
-                    self.w2 = [self._cuda(self.w2), None]
+                    self.w1[0] = self._cuda(w1)
+                    self.w2[0] = self._cuda(w2)
                 delattr(self, "w2_list")
                 delattr(self, "experts_up_projs")
                 delattr(self, "experts_gate_projs")

+    def _fuse_weight_scale(self):
+        with self.lock:
+            if (
+                hasattr(self, "experts_up_proj_scales")
+                and None not in self.experts_up_proj_scales
+                and None not in self.experts_gate_proj_scales
+                and None not in self.w2_scale_list
+            ):
+                w1_scale_list = []
+                for i_experts in range(self.n_routed_experts):
+                    expert_gate_up_proj_scale = torch.cat(
+                        [self.experts_gate_proj_scales[i_experts], self.experts_up_proj_scales[i_experts]], dim=0
+                    )
+                    w1_scale_list.append(expert_gate_up_proj_scale)
+
+                inter_shape, hidden_size = w1_scale_list[0].shape[0], w1_scale_list[0].shape[1]
+                w1_scale = torch._utils._flatten_dense_tensors(w1_scale_list).view(
+                    len(w1_scale_list), inter_shape, hidden_size
+                )
+                inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1]
+                w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view(
+                    len(self.w2_scale_list), inter_shape, hidden_size
+                )
+                self.w1[1] = self._cuda(w1_scale)
+                self.w2[1] = self._cuda(w2_scale)
+                delattr(self, "w2_scale_list")
+                delattr(self, "experts_up_proj_scales")
+                delattr(self, "experts_gate_proj_scales")
+
     def _load_hf_weights_etp(self, weights):
         world_size_ = get_world_size()
         assert self.n_routed_experts % world_size_ == 0
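
_fuse now packs into local w1/w2 tensors and writes them into slot 0 of self.w1/self.w2, while _fuse_weight_scale fills slot 1 with the fused scales. A runnable toy of the flatten-and-view packing both methods use (expert count and shapes are made up):

```python
import torch

# Toy illustration of the fuse packing: each expert's gate and up projections are
# concatenated on dim 0, then all experts are flattened into one contiguous buffer
# and viewed back as [num_experts, 2 * inter, hidden].
num_experts, inter, hidden = 4, 8, 16
gate = [torch.randn(inter, hidden) for _ in range(num_experts)]
up = [torch.randn(inter, hidden) for _ in range(num_experts)]

w1_list = [torch.cat([g, u], dim=0) for g, u in zip(gate, up)]
w1 = torch._utils._flatten_dense_tensors(w1_list).view(num_experts, 2 * inter, hidden)
print(w1.shape)  # torch.Size([4, 16, 16])

# _fuse_weight_scale repeats this pattern on the per-block scale tensors, so self.w1 and
# self.w2 end up as [packed_weight, packed_weight_scale] pairs.
```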
@@ -196,11 +241,51 @@ def load_hf_weights(self, weights):
                 self.w2_list[i_experts] = weights[w2_weight][
                     :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
                 ]
-
+        if self.quant_method is not None:
+            self._load_weight_scale(weights)
         self._fuse()

+    def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None:
+        block_size = 1
+        if hasattr(self.quant_method, "block_size"):
+            block_size = self.quant_method.block_size
+        for i_experts in range(self.n_routed_experts):
+            w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}"
+            w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}"
+            w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}"
+            if w1_scale in weights:
+                self.experts_gate_proj_scales[i_experts] = weights[w1_scale][
+                    self.split_inter_size
+                    // block_size
+                    * self.tp_rank_ : self.split_inter_size
+                    // block_size
+                    * (self.tp_rank_ + 1),
+                    :,
+                ]
+            if w3_scale in weights:
+                self.experts_up_proj_scales[i_experts] = weights[w3_scale][
+                    self.split_inter_size
+                    // block_size
+                    * self.tp_rank_ : self.split_inter_size
+                    // block_size
+                    * (self.tp_rank_ + 1),
+                    :,
+                ]
+
+            if w2_scale in weights:
+                self.w2_scale_list[i_experts] = weights[w2_scale][
+                    :,
+                    self.split_inter_size
+                    // block_size
+                    * self.tp_rank_ : self.split_inter_size
+                    // block_size
+                    * (self.tp_rank_ + 1),
+                ]
+
     def _cuda(self, cpu_tensor):
         device_id = get_current_device_id()
+        if self.quantized_weight:
+            return cpu_tensor.contiguous().cuda(device_id)
         return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id)

     def verify_load(self):
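
_load_weight_scale slices each expert's scale tensor per tensor-parallel rank in units of quantization blocks (block_size falls back to 1 when the quant method defines no block size). A short arithmetic sketch of that indexing with assumed sizes:

```python
# Worked example of the scale slicing in _load_weight_scale (sizes are assumptions):
# each rank owns split_inter_size weight rows, and block quantization stores one scale
# per block_size rows, so rank r owns split_inter_size // block_size scale rows.
split_inter_size, block_size, world_size = 512, 128, 4
for tp_rank in range(world_size):
    start = split_inter_size // block_size * tp_rank
    end = split_inter_size // block_size * (tp_rank + 1)
    print(f"rank {tp_rank}: scale rows [{start}:{end}]")  # rank 0 -> [0:4], rank 1 -> [4:8], ...
```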
