
Commit 3d9c1bc

add cli/env
1 parent 6ea84e8 commit 3d9c1bc

4 files changed: +26 lines, -11 lines

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 10 additions & 10 deletions
@@ -10,7 +10,7 @@ def generate_scale_name(name):
     return weight_scale_name, input_scale_name


-QUANTED_WEIGHT = os.getenv("QUANTED_WEIGHT", "0").upper() in ["1", "TRUE", "ON"]
+STATIC_QUANT = os.getenv("STATIC_QUANT", "0").upper() in ["1", "TRUE", "ON"]


 class MMWeightTpl(BaseWeightTpl):
@@ -43,7 +43,7 @@ def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True):

     def _post_load_weights(self):
         if self.quant_method is not None:
-            if QUANTED_WEIGHT:
+            if STATIC_QUANT:
                 if all(w is not None for w in [self.weight, self.weight_scale, self.input_scale]):
                     self.weight = self.quant_method.quantize((self.weight, self.weight_scale, self.input_scale))
             else:
@@ -86,11 +86,11 @@ def load_hf_weights(self, weights):
             bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end]
             self.bias = bias.cuda(self.tp_rank_)

-        if QUANTED_WEIGHT and self.weight_scale_name in weights:
+        if STATIC_QUANT and self.weight_scale_name in weights:
             weight_scale = weights[self.weight_scale_name].to(torch.float)[self.start : self.end]
             self.weight_scale = weight_scale.cuda()

-        if QUANTED_WEIGHT and self.input_scale_name in weights:
+        if STATIC_QUANT and self.input_scale_name in weights:
             input_scale = weights[self.input_scale_name].to(torch.float)
             self.input_scale = input_scale.cuda()

@@ -122,11 +122,11 @@ def load_hf_weights(self, weights):
             bias = weights[self.bias_name]
             self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.tp_rank_)

-        if QUANTED_WEIGHT and self.weight_scale_name in weights:
+        if STATIC_QUANT and self.weight_scale_name in weights:
             weight_scale = weights[self.weight_scale_name].to(torch.float)
             self.weight_scale = weight_scale.cuda()

-        if QUANTED_WEIGHT and self.input_scale_name in weights:
+        if STATIC_QUANT and self.input_scale_name in weights:
             input_scale = weights[self.input_scale_name].to(torch.float)
             self.input_scale = input_scale.cuda()

@@ -203,10 +203,10 @@ def load_hf_weights(self, weights):
             if self.has_bias and self.bias_names[i] in weights:
                 bias = weights[self.bias_names[i]].to(self.data_type_)
                 self.biases[i] = bias[self.starts[i] : self.ends[i]]
-            if QUANTED_WEIGHT and self.weight_scale_names[i] in weights:
+            if STATIC_QUANT and self.weight_scale_names[i] in weights:
                 weight_scale = weights[self.weight_scale_names[i]][self.starts[i] : self.ends[i]]
                 self.weight_scales[i] = weight_scale.to(torch.float)
-            if QUANTED_WEIGHT and self.input_scale_names[i] in weights:
+            if STATIC_QUANT and self.input_scale_names[i] in weights:
                 input_scale = weights[self.input_scale_names[i]].to(torch.float)
                 self.input_scales[i] = input_scale

@@ -234,10 +234,10 @@ def load_hf_weights(self, weights):
             if self.has_bias and self.bias_names[i] in weights:
                 bias = weights[self.bias_names[i]].to(self.data_type_)
                 self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
-            if QUANTED_WEIGHT and self.weight_scale_names[i] in weights:
+            if STATIC_QUANT and self.weight_scale_names[i] in weights:
                 weight_scale = weights[self.weight_scale_names[i]]
                 self.weight_scales[i] = weight_scale.to(torch.float)
-            if QUANTED_WEIGHT and self.input_scale_names[i] in weights:
+            if STATIC_QUANT and self.input_scale_names[i] in weights:
                 input_scale = weights[self.input_scale_names[i]].to(torch.float)
                 self.input_scales[i] = input_scale
         self._fuse()
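In short, this file renames the QUANTED_WEIGHT toggle to STATIC_QUANT: when the environment variable is set, the loader expects the checkpoint to already carry weight_scale / input_scale tensors and repacks them, instead of quantizing the weights at load time. A minimal sketch of that branch follows; it is not the real class method, and the dynamic else-branch body is a guess since it is unchanged (and not shown) in this diff:

```python
import os

# Same truthy-string parsing as the new module-level flag in mm_weight.py.
STATIC_QUANT = os.getenv("STATIC_QUANT", "0").upper() in ["1", "TRUE", "ON"]


def post_load(weight, weight_scale, input_scale, quant_method):
    """Illustrative stand-in for _post_load_weights (not the real method)."""
    if STATIC_QUANT:
        # Static path: the checkpoint already ships per-tensor scales, so the
        # quant method just repacks (weight, weight_scale, input_scale).
        if all(w is not None for w in (weight, weight_scale, input_scale)):
            return quant_method.quantize((weight, weight_scale, input_scale))
        return weight  # scales not loaded yet; keep the raw weight for now
    # Dynamic path (sketch of the pre-existing else branch): quantize the
    # floating-point weight at load time with no stored scales.
    return quant_method.quantize(weight)


print("STATIC_QUANT:", STATIC_QUANT)
```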

lightllm/server/api_cli.py

Lines changed: 11 additions & 0 deletions
@@ -1,6 +1,12 @@
+import os
 import argparse


+def push_env(args):
+    if args.static_quant:
+        os.environ["STATIC_QUANT"] = "1"
+
+
 def make_argument_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()

@@ -244,4 +250,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
         help="""Path of quantization config. It can be used for mixed quantization.
         Examples can be found in lightllm/common/quantization/configs.""",
     )
+    parser.add_argument(
+        "--static_quant",
+        action="store_true",
+        help="whether to load static quantized weights. Currently, only vllm-w8a8 is supported.",
+    )
     return parser
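mm_weight.py evaluates STATIC_QUANT once at import time, so the new CLI flag has to be exported as an environment variable before any model code is imported. A small, self-contained sketch of the parse_args → push_env handoff (the hard-coded argv is only there so the snippet runs standalone):

```python
import argparse
import os

# Minimal reproduction of the new wiring: parse_args() -> push_env(args),
# so that STATIC_QUANT is visible to mm_weight.py when it is imported.
parser = argparse.ArgumentParser()
parser.add_argument("--static_quant", action="store_true")
args = parser.parse_args(["--static_quant"])  # simulate passing the flag

if args.static_quant:                 # exactly what push_env(args) does
    os.environ["STATIC_QUANT"] = "1"

print(os.environ.get("STATIC_QUANT", "0"))  # -> "1"
```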

lightllm/server/api_server.py

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,7 @@
 from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
 from fastapi.responses import Response, StreamingResponse, JSONResponse
 import uvicorn
-from .api_cli import make_argument_parser
+from .api_cli import make_argument_parser, push_env
 from .sampling_params import SamplingParams
 from .multimodal_params import MultimodalParams
 from .httpserver.manager import HttpServerManager
@@ -390,6 +390,7 @@ async def startup_event():
     torch.multiprocessing.set_start_method("spawn"),  # this code will not be ok for settings to fork to subprocess
     parser = make_argument_parser()
     args = parser.parse_args()
+    push_env(args)
     g_objs.args = args
     from .api_start import normal_or_p_d_start, pd_master_start

lightllm/server/api_start.py

Lines changed: 3 additions & 0 deletions
@@ -50,6 +50,9 @@ def normal_or_p_d_start(g_objs):
         args.mem_fraction > 0 and args.mem_fraction < 1
     ), f"Invalid mem_fraction {args.mem_fraction}, The expected value is between 0 and 1."

+    if args.static_quant:
+        assert args.quant_type == "vllm-w8a8", "Only static parameter loading for vllm-w8a8 is supported."
+
     # splitfuse_mode and cuda_graph cannot be enabled at the same time
     if args.splitfuse_mode:
         assert args.disable_cudagraph
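The new guard only allows --static_quant together with the vllm-w8a8 backend. A minimal sketch of the check in isolation (the helper name and the alternative quant-type string are made up for illustration):

```python
def check_static_quant_args(static_quant: bool, quant_type: str) -> None:
    # Mirrors the new assertion in normal_or_p_d_start.
    if static_quant:
        assert quant_type == "vllm-w8a8", "Only static parameter loading for vllm-w8a8 is supported."


check_static_quant_args(True, "vllm-w8a8")         # passes
check_static_quant_args(False, "some-other-type")  # passes: flag not set
# check_static_quant_args(True, "some-other-type") would raise AssertionError
```

In practice this means a static-quant launch must pass --quant_type vllm-w8a8 alongside --static_quant.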
