Commit adccb9d

Static quant for vllm-w8a8 (#659)

1 parent 5b1ff80 commit adccb9d

4 files changed: +104 -12 lines changed

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 78 additions & 7 deletions
@@ -1,15 +1,27 @@
+import os
 import torch
 from .base_weight import BaseWeightTpl
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


+def generate_scale_name(name):
+    weight_scale_name = ".".join(name.split(".")[:-1] + ["weight_scale"])
+    input_scale_name = ".".join(name.split(".")[:-1] + ["input_scale"])
+    return weight_scale_name, input_scale_name
+
+
+STATIC_QUANT = os.getenv("STATIC_QUANT", "0").upper() in ["1", "TRUE", "ON"]
+
+
 class MMWeightTpl(BaseWeightTpl):
     def __init__(self, data_type):
         super().__init__()
         self.data_type_ = data_type
         self.quant_method = None
         self.weight = None
         self.bias = None
+        self.weight_scale = None
+        self.input_scale = None

     def set_quant_method(self, quant_method):
         self.quant_method = quant_method
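For example, generate_scale_name derives the companion scale keys by replacing the last component of a weight's checkpoint key (the layer name below is hypothetical, for illustration only):

# Illustration only; "model.layers.0.q_proj.weight" is a hypothetical checkpoint key.
weight_scale_name, input_scale_name = generate_scale_name("model.layers.0.q_proj.weight")
# weight_scale_name == "model.layers.0.q_proj.weight_scale"
# input_scale_name  == "model.layers.0.q_proj.input_scale"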
@@ -31,7 +43,11 @@ def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True):

     def _post_load_weights(self):
         if self.quant_method is not None:
-            self.weight = self.quant_method.quantize(self.weight.cuda(self.tp_rank_))
+            if STATIC_QUANT:
+                if all(w is not None for w in [self.weight, self.weight_scale, self.input_scale]):
+                    self.weight = self.quant_method.quantize((self.weight, self.weight_scale, self.input_scale))
+            else:
+                self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.tp_rank_))
             return
         self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
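In static mode, quantize therefore receives the precomputed (weight, weight_scale, input_scale) tuple instead of a floating-point tensor; the matching isinstance(weight, tuple) branch is added in vllm_quant.py below. If any of the three pieces has not arrived yet, the weight is left untouched until a later load_hf_weights call completes the set.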

@@ -43,6 +59,7 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
         self.end = split_n_embed * (self.tp_rank_ + 1)
         self.weight_name = weight_name
         self.bias_name = bias_name
+        self.weight_scale_name, self.input_scale_name = generate_scale_name(weight_name)

     def verify_load(self):
         load_ok = True
@@ -60,13 +77,24 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):

     def load_hf_weights(self, weights):
         weight = None
+        weight_scale = None
+        input_scale = None
         if self.weight_name in weights:
-            weight = weights[self.weight_name].to(self.data_type_)
+            weight = weights[self.weight_name]
             self.weight = weight[self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end]
             self.bias = bias.cuda(self.tp_rank_)
-        if weight is None:
+
+        if STATIC_QUANT and self.weight_scale_name in weights:
+            weight_scale = weights[self.weight_scale_name].to(torch.float)[self.start : self.end]
+            self.weight_scale = weight_scale.cuda()
+
+        if STATIC_QUANT and self.input_scale_name in weights:
+            input_scale = weights[self.input_scale_name].to(torch.float)
+            self.input_scale = input_scale.cuda()
+
+        if weight is None and weight_scale is None and input_scale is None:
             return
         self._post_load_weights()
         return
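Note that the eager .to(self.data_type_) cast is dropped from load_hf_weights here: a statically quantized checkpoint may store the weight as int8, so casting to the runtime dtype is deferred to the dynamic branch of _post_load_weights.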
@@ -85,13 +113,24 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):

     def load_hf_weights(self, weights):
         weight = None
+        weight_scale = None
+        input_scale = None
         if self.weight_name in weights:
-            weight = weights[self.weight_name].to(self.data_type_)
+            weight = weights[self.weight_name]
             self.weight = weight[:, self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name]
             self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.tp_rank_)
-        if weight is None:
+
+        if STATIC_QUANT and self.weight_scale_name in weights:
+            weight_scale = weights[self.weight_scale_name].to(torch.float)
+            self.weight_scale = weight_scale.cuda()
+
+        if STATIC_QUANT and self.input_scale_name in weights:
+            input_scale = weights[self.input_scale_name].to(torch.float)
+            self.input_scale = input_scale.cuda()
+
+        if weight is None and weight_scale is None and input_scale is None:
             return
         self._post_load_weights()
         return
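This row-parallel variant slices the weight along the input dimension, so the per-output-channel weight_scale is loaded whole rather than sliced as in the column-parallel case above.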
@@ -109,8 +148,17 @@ def __init__(self, weight_names, data_type, split_n_embeds, bias_names=[]):
         self.ends = [i * (self.tp_rank_ + 1) for i in self.split_n_embeds]
         self.weight_names = weight_names
         self.bias_names = bias_names
+        self.weight_scale_names = []
+        self.input_scale_names = []
+        for weight_name in weight_names:
+            weight_scale_name, input_scale_name = generate_scale_name(weight_name)
+            self.weight_scale_names.append(weight_scale_name)
+            self.input_scale_names.append(input_scale_name)
+
         self.weights = [None] * len(self.weight_names)
         self.biases = [None] * len(self.bias_names)
+        self.input_scales = [None] * len(self.weight_names)
+        self.weight_scales = [None] * len(self.weight_names)
         self.has_bias = all(b is not None for b in self.bias_names) and len(bias_names) > 0

     def verify_load(self):
@@ -131,6 +179,16 @@ def _fuse(self):
         if self.weight is None and all(w is not None for w in self.weights):
             self.weight = torch.cat(self.weights, dim=0)
             self._post_load_weights()
+
+        if self.weight_scale is None and all(w is not None for w in self.weight_scales):
+            self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda()
+            self._post_load_weights()
+
+        if self.input_scale is None and all(w is not None for w in self.input_scales):
+            input_scales = torch.stack(self.input_scales, dim=0)
+            self.input_scale = torch.max(input_scales).cuda()
+            self._post_load_weights()
+
         if self.has_bias:
             if self.bias is None and all(b is not None for b in self.biases):
                 self.bias = torch.cat(self.biases, dim=0).cuda(self.tp_rank_)
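Weight scales are per output channel, so for fused projections they concatenate along dim 0 exactly like the weights. The fused layers share one input, however, so only a single activation scale can apply; taking the max of the per-projection input scales is the conservative choice, avoiding overflow for every branch at the cost of some precision. A minimal sketch with made-up values:

# Fusing static scales for a hypothetical two-way merge; values are made up.
import torch

weight_scales = [torch.tensor([0.010, 0.020]), torch.tensor([0.030, 0.040])]
input_scales = [torch.tensor(0.5), torch.tensor(0.8)]

weight_scale = torch.cat(weight_scales, dim=0)      # tensor([0.0100, 0.0200, 0.0300, 0.0400])
input_scale = torch.max(torch.stack(input_scales))  # tensor(0.8000), shared by all fused outputs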
@@ -140,11 +198,18 @@ def load_hf_weights(self, weights):
         weight = None
         for i in range(len(self.weight_names)):
             if self.weight_names[i] in weights:
-                weight = weights[self.weight_names[i]].to(self.data_type_)
+                weight = weights[self.weight_names[i]]
                 self.weights[i] = weight[self.starts[i] : self.ends[i]]
             if self.has_bias and self.bias_names[i] in weights:
                 bias = weights[self.bias_names[i]].to(self.data_type_)
                 self.biases[i] = bias[self.starts[i] : self.ends[i]]
+            if STATIC_QUANT and self.weight_scale_names[i] in weights:
+                weight_scale = weights[self.weight_scale_names[i]][self.starts[i] : self.ends[i]]
+                self.weight_scales[i] = weight_scale.to(torch.float)
+            if STATIC_QUANT and self.input_scale_names[i] in weights:
+                input_scale = weights[self.input_scale_names[i]].to(torch.float)
+                self.input_scales[i] = input_scale
+
         self._fuse()
         return

@@ -164,11 +229,17 @@ def load_hf_weights(self, weights):
         weight = None
         for i in range(len(self.weight_names)):
             if self.weight_names[i] in weights:
-                weight = weights[self.weight_names[i]].to(self.data_type_)
+                weight = weights[self.weight_names[i]]
                 self.weights[i] = weight[:, self.starts[i] : self.ends[i]]
             if self.has_bias and self.bias_names[i] in weights:
                 bias = weights[self.bias_names[i]].to(self.data_type_)
                 self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
+            if STATIC_QUANT and self.weight_scale_names[i] in weights:
+                weight_scale = weights[self.weight_scale_names[i]]
+                self.weight_scales[i] = weight_scale.to(torch.float)
+            if STATIC_QUANT and self.input_scale_names[i] in weights:
+                input_scale = weights[self.input_scale_names[i]].to(torch.float)
+                self.input_scales[i] = input_scale
         self._fuse()
         return

lightllm/common/quantization/vllm_quant.py

Lines changed: 13 additions & 5 deletions
@@ -31,23 +31,31 @@ def __init__(self):
         super().__init__()

     def quantize(self, weight: torch.Tensor):
-        if hasattr(weight, "scale"):
-            return weight.data.transpose(0, 1).cuda(), weight.scale.cuda()
+        if isinstance(weight, tuple):
+            return (weight[0].transpose(0, 1).cuda(),) + weight[1:]
         weight = weight.float()
         scale = weight.abs().max(dim=-1)[0] / 127
         weight = weight.transpose(0, 1) / scale.reshape(1, -1)
         weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8)
         return weight.cuda(), scale.cuda()

     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
-        x_q, x_scale, x_zp = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
+        input_scale = None
+        if len(weights) == 3:
+            qweight, weight_scale, input_scale = weights
+        elif len(weights) == 2:
+            qweight, weight_scale = weights
+        else:
+            raise ValueError("vllm-quant Weights must be a tuple of length 2 or 3.")
+
+        x_q, x_scale, x_zp = ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
         m = input_tensor.shape[0]
-        n = weights[0].shape[1]
+        n = qweight.shape[1]
         if out is None:
             out = g_cache_manager.alloc_tensor(
                 (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
             )
-        torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
+        torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
         return out
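The dynamic branch of quantize computes a symmetric per-output-channel scale, max|w| / 127 along each row, then scales, clamps, and rounds into int8, returning the weight transposed for the GEMM. A standalone sketch of the same arithmetic on a made-up tensor:

# Symmetric per-channel int8 quantization, mirroring the dynamic branch above;
# the tensor values are made up for illustration.
import torch

w = torch.tensor([[0.50, -1.27], [2.54, 0.10]])  # (out_features, in_features)
scale = w.abs().max(dim=-1)[0] / 127             # tensor([0.0100, 0.0200]), one scale per output channel
w_q = torch.round((w.transpose(0, 1) / scale.reshape(1, -1)).clamp(min=-128, max=127)).to(torch.int8)
# dequantize: w ≈ (w_q.float() * scale.reshape(1, -1)).transpose(0, 1)

With static scales loaded, apply instead passes the stored input_scale into ops.scaled_int8_quant, so activations are quantized with the calibration-time scale rather than one measured per call.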

lightllm/server/api_cli.py

Lines changed: 5 additions & 0 deletions
@@ -244,4 +244,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
         help="""Path of quantization config. It can be used for mixed quantization.
             Examples can be found in lightllm/common/quantization/configs.""",
     )
+    parser.add_argument(
+        "--static_quant",
+        action="store_true",
+        help="whether to load static quantized weights. Currently, only vllm-w8a8 is supported.",
+    )
     return parser
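The new flag only selects the loading path; it must be combined with --quant_type vllm-w8a8 (enforced in api_start.py below), and it takes effect by exporting STATIC_QUANT=1, which mm_weight.py reads at import time.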

lightllm/server/api_start.py

Lines changed: 8 additions & 0 deletions

@@ -17,6 +17,9 @@

 logger = init_logger(__name__)

+def set_env(args):
+    if args.static_quant:
+        os.environ["STATIC_QUANT"] = "1"

 def normal_or_p_d_start(g_objs):
     from .api_server import G_Objs
@@ -45,11 +48,16 @@ def normal_or_p_d_start(g_objs):

     logger.info(f"use tgi api: {args.use_tgi_api}")

+    set_env(args)
+
     assert not (args.beam_mode and args.use_dynamic_prompt_cache), "Beam mode incompatible with dynamic prompt cache"
     assert (
         args.mem_fraction > 0 and args.mem_fraction < 1
     ), f"Invalid mem_fraction {args.mem_fraction}, The expected value is between 0 and 1."

+    if args.static_quant:
+        assert args.quant_type == "vllm-w8a8", "Only static parameter loading for vllm-w8a8 is supported."
+
     # splitfuse_mode and cuda_graph cannot be enabled at the same time
     if args.splitfuse_mode:
         assert args.disable_cudagraph
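End to end, the flag travels through an environment variable rather than through the model config. A minimal sketch of the handshake, using the names from the diff above:

# How --static_quant reaches the weight loader; illustration only.
import os

os.environ["STATIC_QUANT"] = "1"  # what set_env(args) does when --static_quant is passed

# mm_weight.py evaluates this once, at import time:
STATIC_QUANT = os.getenv("STATIC_QUANT", "0").upper() in ["1", "TRUE", "ON"]
assert STATIC_QUANT

Because STATIC_QUANT is a module-level constant, flipping the variable after the model modules are imported has no effect, which is presumably why set_env(args) runs at the top of normal_or_p_d_start, before any model loading.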