
Commit 4fcca13

[main][Feature] Support Qwen3 W4A8 quantization (vllm-project#2060)
### What this PR does / why we need it?
Adds `W4A8_DYNAMIC` quantization support for linear layers. Dense models such as Qwen3 can now run inference with `W4A8_DYNAMIC` quantization.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
- Added a unit-test case in `tests/ut/quantization/test_w4a8_dynamic.py`.
- Added an e2e case in `tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC` to test a Qwen3 W4A8_DYNAMIC quantized model.

Note: the W4A8_DYNAMIC quantized model is produced by `msit/msmodelslim` at commit `d0abb0a47e1f1a473b866ad41b737fbc28fb1409`.

1. Generate `W4A8_DYNAMIC` quantization weights using `msmodelslim`:
```shell
git clone https://gitee.com/ascend/msit.git
cd msit/msmodelslim
git checkout d0abb0a47e1f1a473b866ad41b737fbc28fb1409
bash install.sh
```
2. Serve the model using `vllm`:
```shell
VLLM_USE_V1=1 python -m vllm.entrypoints.openai.api_server \
    --model vllm-ascend/Qwen3-8B-W4A8 \
    --port 8000 \
    --quantization ascend \
    --tensor_parallel_size 2 \
    --enforce-eager
```

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@4cd7fe6

Signed-off-by: ZhouXiang <[email protected]>
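For a quick sanity check of the served endpoint, a request like the sketch below can be used (not part of this PR; it assumes the `openai` Python client and reuses the port and model name from the serving command above):

```python
# Hypothetical client-side smoke test against the server started in step 2.
# Assumes the `openai` Python package is installed; port and model name come
# from the serving command above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="vllm-ascend/Qwen3-8B-W4A8",  # must match the --model argument
    prompt="Hello, my name is",
    max_tokens=5,
)
print(completion.choices[0].text)
```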
1 parent 6874d66 · commit 4fcca13

8 files changed: +185 −1 lines changed


.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -278,6 +278,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
           pytest -sv tests/e2e/multicard/test_data_parallel.py
           pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
             --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
```

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -166,7 +166,22 @@ def test_models_distributed_Qwen3_W8A8():
     with VllmRunner(
             snapshot_download("vllm-ascend/Qwen3-8B-W8A8"),
             max_model_len=8192,
-            enforce_eager=True,
+            dtype="auto",
+            tensor_parallel_size=2,
+            quantization="ascend",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_models_distributed_Qwen3_W4A8DYNAMIC():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download("vllm-ascend/Qwen3-8B-W4A8"),
+            max_model_len=8192,
             dtype="auto",
             tensor_parallel_size=2,
             quantization="ascend",
```
tests/ut/quantization/test_w4a8_dynamic.py

Lines changed: 27 additions & 0 deletions (new file)

```python
import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w4a8_dynamic import AscendW4A8DynamicLinearMethod


class TestAscendW4A8DynamicLinearMethod(TestBase):

    def setUp(self):
        self.method = AscendW4A8DynamicLinearMethod()
        self.method.group_size = 8

    def test_get_weight(self):
        weight = self.method.get_weight(8, 32, torch.bfloat16)
        self.assertEqual(weight["weight"].dtype, torch.int8)
        self.assertEqual(weight["weight"].shape, (32, 8))

    def test_get_pergroup_param(self):
        params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
        self.assertEqual(params["weight_scale"].dtype, torch.bfloat16)
        self.assertEqual(params["weight_scale"].shape, (32, 1))
        self.assertEqual(params["weight_offset"].dtype, torch.bfloat16)
        self.assertEqual(params["weight_offset"].shape, (32, 1))
        self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16)
        self.assertEqual(params["weight_scale_second"].shape, (32, 1))
        self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
        self.assertEqual(params["weight_offset_second"].shape, (32, 1))
```
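The unit test pins `group_size` to 8 so that `input_size // group_size == 1`; with the default `group_size` of 256 the second-level scale and offset carry one column per weight group along the input dimension. A minimal shape sketch (the dimensions below are hypothetical, chosen only for illustration):

```python
# Shape bookkeeping for the per-group parameters, mirroring get_pergroup_param.
# The sizes below are hypothetical and only chosen for illustration.
import torch

input_size, output_size, group_size = 4096, 1024, 256  # 256 is the default group size

weight_scale = torch.empty(output_size, 1, dtype=torch.bfloat16)         # first-level, per output channel
weight_scale_second = torch.empty(output_size, input_size // group_size,
                                  dtype=torch.bfloat16)                  # one column per input group

print(weight_scale.shape)         # torch.Size([1024, 1])
print(weight_scale_second.shape)  # torch.Size([1024, 16])
```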

vllm_ascend/quantization/quant_config.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -205,6 +205,17 @@ def create_weights(
             layer.register_parameter(perchannel_name, param)
             set_weight_attrs(param, extra_weight_attrs)
 
+        pergroup_dict = self.quant_method.get_pergroup_param(
+            input_size_per_partition, output_size_per_partition, params_dtype)
+        for pergroup_name, pergroup_param in pergroup_dict.items():
+            param = torch.nn.Parameter(pergroup_param, requires_grad=False)
+            set_weight_attrs(param, {"output_dim": 0})
+            layer.register_parameter(pergroup_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+            if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
+                setattr(param, "input_dim", 1)
+                param.input_dim = 1
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
             self.quant_method.process_weights_after_loading(layer)
```
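The `output_dim` / `input_dim` attributes set above are the hints vLLM's tensor-parallel weight loaders use to decide along which dimension a checkpoint tensor is sliced per rank; the second-level per-group scales vary along the input dimension, hence `input_dim = 1`. A simplified sketch of that slicing (an illustration of the idea, not vLLM's actual loader; `shard_dim`, `tp_rank`, and `tp_size` are assumed inputs):

```python
import torch


def shard_for_rank(loaded: torch.Tensor, shard_dim: int, tp_rank: int,
                   tp_size: int) -> torch.Tensor:
    """Illustrative only: take this rank's slice of a checkpoint tensor.

    A row-parallel layer would pass the parameter's input_dim (1 for
    weight_scale_second), a column-parallel layer its output_dim (0).
    """
    shard_size = loaded.shape[shard_dim] // tp_size
    return loaded.narrow(shard_dim, tp_rank * shard_size, shard_size)
```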

vllm_ascend/quantization/quantizer.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -23,6 +23,7 @@
 from vllm.logger import logger
 
 from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
+from .w4a8_dynamic import AscendW4A8DynamicLinearMethod
 from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
                    AscendW8A8LinearMethod)
 from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
@@ -263,6 +264,13 @@ def get_quantizer(cls,
                 f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}")
 
 
+class W4A8DYNAMICQuantizer(VLLMAscendQuantizer):
+
+    @staticmethod
+    def build_linear_method():
+        return AscendW4A8DynamicLinearMethod()
+
+
 class W8A8Quantizer(VLLMAscendQuantizer):
 
     @staticmethod
@@ -290,6 +298,7 @@ def build_moe_method():
 
 
 SUPPORT_ASCEND_QUANTIZER_TYPE = {
+    "W4A8_DYNAMIC": W4A8DYNAMICQuantizer,
     "W8A8": W8A8Quantizer,
     "W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
     "C8": W8A8Quantizer,
```
vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 113 additions & 0 deletions (new file)

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Optional

import torch
import torch_npu
from vllm.config import get_current_vllm_config


class AscendW4A8DynamicLinearMethod:
    """Linear method for Ascend W4A8_DYNAMIC
    """

    def __init__(self):
        self.transpose_weight = True
        try:
            self.group_size = get_current_vllm_config(
            ).quant_config.quant_description.get("group_size", 256)
        except AttributeError:
            self.group_size = 256

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        return {}

    @staticmethod
    def get_perchannel_param(output_size: int,
                             params_dtype: torch.dtype) -> Dict[str, Any]:
        return {}

    def get_pergroup_param(self, input_size: int, output_size: int,
                           params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        params_dict["weight_scale_second"] = torch.empty(output_size,
                                                         input_size //
                                                         self.group_size,
                                                         dtype=params_dtype)
        params_dict["weight_offset_second"] = torch.empty(output_size,
                                                          input_size //
                                                          self.group_size,
                                                          dtype=params_dtype)
        return params_dict

    @staticmethod
    def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
                             per_group_scale: torch.Tensor):
        k, n = weight.shape
        group_num, n = per_group_scale.shape
        weight_high = weight.to(torch.float32).reshape(
            group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
        weight_high = weight_high.reshape(k, n)
        bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
        antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
        return antiquant_scale.npu(), bias

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = None,
    ) -> torch.Tensor:
        return torch_npu.npu_weight_quant_batchmatmul(
            x,
            layer.weight,
            antiquant_scale=layer.weight_scale_second.to(x.dtype),
            antiquant_group_size=self.group_size,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module):
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight_scale.data = layer.weight_scale.data.flatten().to(
            torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()
        layer.weight_scale_second.data, scale_bias = self.process_scale_second(
            layer.weight.data,
            layer.weight_scale.data,
            layer.weight_scale_second.data.transpose(0, 1).contiguous(),
        )
        param = torch.nn.Parameter(scale_bias, requires_grad=False)
        layer.register_parameter("weight_scale_bias", param)
        layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
            layer.weight.data.to(torch.int32))
```
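`process_weights_after_loading` folds the per-channel scale into the per-group scale, derives a bias term built around the constant 8 in `process_scale_second`, and then packs the int8-stored weights into int4. The arithmetic can be replayed on CPU for a toy size, as in the sketch below (the only deviation from the code above is that the result is not moved to the NPU, so it runs without `torch_npu`; the shapes are hypothetical):

```python
# CPU-only restatement of process_scale_second on toy shapes, to show what the
# combined antiquant scale and bias look like. Sizes are hypothetical.
import torch

k, n, group_size = 8, 4, 4                                # k = input dim, n = output dim
group_num = k // group_size                               # 2 groups along the input dim

weight = torch.randint(-8, 8, (k, n), dtype=torch.int8)   # unpacked int4 values stored as int8
scale = torch.rand(n)                                     # first-level, per output channel
per_group_scale = torch.rand(group_num, n)                # second-level, per group and channel

weight_high = weight.to(torch.float32).reshape(group_num, -1, n) \
    * per_group_scale.reshape(group_num, 1, n)            # dequantize group-wise
weight_high = weight_high.reshape(k, n)
bias = 8 * (weight_high * scale).sum(dim=0)               # bias term built around the constant 8
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)

print(antiquant_scale.shape, bias.shape)                  # torch.Size([2, 4]) torch.Size([4])
```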

vllm_ascend/quantization/w8a8.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -84,6 +84,10 @@ def get_perchannel_param(
                                                    dtype=params_dtype)
         return params_dict
 
+    def get_pergroup_param(self, input_size: int, output_size: int,
+                           params_dtype: torch.dtype) -> Dict[str, Any]:
+        return {}
+
     @staticmethod
     def apply(
         layer: torch.nn.Module,
```

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -699,6 +699,10 @@ def get_perchannel_param(
                                                    dtype=params_dtype)
         return params_dict
 
+    def get_pergroup_param(self, input_size: int, output_size: int,
+                           params_dtype: torch.dtype) -> Dict[str, Any]:
+        return {}
+
     @staticmethod
     def apply(
         layer: torch.nn.Module,
```
