
Commit 1cff518

HDCharles authored and Isotr0py committed
[refactor] CTMoEMethods to use QuantizationArgs (vllm-project#28871)
Signed-off-by: HDCharles <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent a2d83d6 commit 1cff518
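
In short, each CompressedTensors*MoEMethod constructor now receives the resolved
QuantizationArgs for weights and input activations instead of the whole
CompressedTensorsConfig. A minimal before/after sketch of one call site, with the
surrounding dispatch logic elided (both signatures are taken from the diff below):

    # Before: the MoE method dug its scheme out of the full config object
    method = CompressedTensorsW8A8Fp8MoEMethod(quant_config, layer.moe_config, layer_name)

    # After: the caller resolves the scheme once and passes QuantizationArgs directly
    method = CompressedTensorsW8A8Fp8MoEMethod(weight_quant, input_quant, layer.moe_config)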

File tree: 2 files changed (+86, -75 lines)

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 4 additions & 2 deletions
@@ -767,8 +767,10 @@ def get_scheme_dict(
                 targets=self.target_scheme_map.keys(),
                 fused_mapping=self.packed_modules_mapping,
             )
-
-            return self.target_scheme_map[matched_target]
+            scheme_dict = self.target_scheme_map[matched_target]
+            if scheme_dict.get("format") is None:
+                scheme_dict["format"] = self.quant_format
+            return scheme_dict

         return None
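
The behavioral change in this file is small: get_scheme_dict now back-fills the
scheme's "format" entry from the config-wide quant_format when the matched scheme
does not set one. A minimal standalone sketch of that fallback (the helper name is
illustrative, not part of vLLM):

    def _with_default_format(scheme_dict: dict, quant_format: str) -> dict:
        # Mirrors the change above: fill in "format" only if the scheme left it unset.
        if scheme_dict.get("format") is None:
            scheme_dict["format"] = quant_format
        return scheme_dict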

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 82 additions & 73 deletions
@@ -7,7 +7,11 @@

 import torch
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
+from compressed_tensors.quantization import (
+    ActivationOrdering,
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from torch.nn.parameter import Parameter

 import vllm.envs as envs
@@ -142,10 +146,26 @@ def get_moe_method(
         # are supported + check if the layer is being ignored.
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
+        format = scheme_dict.get("format")

         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             # group_size=None means channelwise
             group_size = weight_quant.group_size or -1
+
+            valid_format_and_bits = (
+                weight_quant.num_bits in WNA16_SUPPORTED_BITS
+                and format == CompressionFormat.pack_quantized.value
+            )
+
+            if not valid_format_and_bits:
+                raise ValueError(
+                    "For Fused MoE layers, only format: "
+                    f"{CompressionFormat.pack_quantized.value} "
+                    f"and bits: {WNA16_SUPPORTED_BITS} are supported, "
+                    f"but got format: {format} "
+                    f"and bits: {weight_quant.num_bits}"
+                )
+
             # Prefer to use the MarlinMoE kernel when it is supported.
             if (
                 not check_moe_marlin_supports_layer(layer, group_size)
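
For context, the WNA16 support check that previously lived in the Marlin and
non-Marlin MoE constructors (removed further down in this diff) is now performed
once here. The same condition restated as a standalone predicate, as a sketch only:
the function name is illustrative, and supported_bits stands in for this module's
WNA16_SUPPORTED_BITS (assumed here to be 4 and 8):

    from compressed_tensors import CompressionFormat

    def wna16_moe_supported(num_bits: int, fmt: str | None, supported_bits=(4, 8)) -> bool:
        # Same condition as valid_format_and_bits above.
        return num_bits in supported_bits and fmt == CompressionFormat.pack_quantized.value
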
@@ -161,12 +181,12 @@ def get_moe_method(
                 )
                 logger.info_once("Using CompressedTensorsWNA16MoEMethod")
                 return CompressedTensorsWNA16MoEMethod(
-                    quant_config, layer.moe_config, layer_name
+                    weight_quant, input_quant, layer.moe_config
                 )
             else:
                 logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                 return CompressedTensorsWNA16MarlinMoEMethod(
-                    quant_config, layer.moe_config, layer_name
+                    weight_quant, input_quant, layer.moe_config
                 )
         elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
@@ -176,15 +196,15 @@ def get_moe_method(
             or quant_config._is_fp8_w8a8(weight_quant, input_quant)
         ):
             return CompressedTensorsW8A8Fp8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
             return CompressedTensorsW4A8Int8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         else:
             raise RuntimeError(
@@ -651,17 +671,19 @@ def apply(
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
-        super().__init__(moe)
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations"
+        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+            CompressedTensorsConfig,
         )

+        super().__init__(moe)
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+
         per_tensor = (
             self.weight_quant.strategy == QuantizationStrategy.TENSOR
             and self.input_quant.strategy == QuantizationStrategy.TENSOR
@@ -699,11 +721,13 @@ def __init__(
         self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

         # cutlass path
-        self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
+        self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
             self.weight_quant, self.input_quant
         )
         self.use_cutlass = not self.block_quant and (
-            quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
+            CompressedTensorsConfig._is_fp8_w8a8_sm90(
+                self.weight_quant, self.input_quant
+            )
             or self.is_fp8_w8a8_sm100
         )
         self.disable_expert_map = False
@@ -1263,16 +1287,14 @@ def supports_eplb(self) -> bool:
 class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations"
-        )
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant

         per_channel = (
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
@@ -1417,36 +1439,27 @@ def apply(
 class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs | None,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        # TODO: @dsikka: refactor this to use schemes as other kernels
-        # are supported + check if the layer is being ignored.
-        config = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.num_bits = config.num_bits
-        self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy
-        self.group_size = config.group_size
-        self.actorder = config.actorder
-        self.layer_name = layer_name
-        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
-        assert config.symmetric, "Only symmetric quantization is supported for MoE"
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+        assert weight_quant.symmetric, (
+            "Only symmetric quantization is supported for MoE"
+        )
+        # Extract properties from weight_quant
+        self.num_bits = weight_quant.num_bits
+        self.packed_factor = 32 // weight_quant.num_bits
+        self.strategy = weight_quant.strategy
+        self.group_size = weight_quant.group_size
+        self.actorder = weight_quant.actorder

-        if not (
-            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
-            and self.num_bits in WNA16_SUPPORTED_BITS
-        ):
-            raise ValueError(
-                "For Fused MoE layers, only ",
-                f"{CompressionFormat.pack_quantized.value} ",
-                "is supported for the following bits: ",
-                f"{WNA16_SUPPORTED_BITS}",
-            )
         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
         self.use_marlin = True
+        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)

     def create_weights(
         self,
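
As a side note, the packed_factor carried over above is simply how many quantized
values fit in one 32-bit word. A quick worked example of the 32 // num_bits
arithmetic (plain Python, no vLLM imports):

    # packed_factor = 32 // num_bits: quantized values packed per 32-bit word
    for num_bits in (4, 8):
        print(f"{num_bits}-bit -> {32 // num_bits} values per int32 word")
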
@@ -1816,35 +1829,26 @@ def apply(
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs | None,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        # TODO: @dsikka: refactor this to use schemes as other kernels
-        # are supported + check if the layer is being ignored.
-        config = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.num_bits = config.num_bits
-        self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+        # Extract properties from weight_quant
+        self.num_bits = weight_quant.num_bits
+        self.packed_factor = 32 // weight_quant.num_bits
+        self.strategy = weight_quant.strategy
         # channelwise is not supported by this kernel
-        assert config.strategy == "group"
-        self.group_size = config.group_size
+        assert weight_quant.strategy == "group"
+        self.group_size = weight_quant.group_size
         # grouped actorder isn't supported by this kernel
-        assert config.actorder != "group"
-        assert config.symmetric, "Only symmetric quantization is supported for MoE"
-
-        if not (
-            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
-            and self.num_bits in WNA16_SUPPORTED_BITS
-        ):
-            raise ValueError(
-                "For Fused MoE layers, only ",
-                f"{CompressionFormat.pack_quantized.value} ",
-                "is supported for the following bits: ",
-                f"{WNA16_SUPPORTED_BITS}",
-            )
+        assert weight_quant.actorder != "group"
+        assert weight_quant.symmetric, (
+            "Only symmetric quantization is supported for MoE"
+        )

     def create_weights(
         self,
@@ -2070,28 +2074,33 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):

     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
         self.has_bias = self.moe.has_bias
-        self.quant_config = quant_config
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant

         # Validate scheme: weights=W4 (channel or group),
         # activations=dynamic TOKEN (A8)
-        wq = self.quant_config.target_scheme_map["Linear"].get("weights")
-        aq = self.quant_config.target_scheme_map["Linear"].get("input_activations")

         # Must be dynamic per-token activations
-        if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
+        if (
+            input_quant.strategy != QuantizationStrategy.TOKEN
+            or not input_quant.dynamic
+        ):
             raise ValueError(
                 "W4A8-int MoE needs dynamic per-token activation quantization."
             )

         # Weight can be channel-wise (group_size=None) or group-wise
-        self.group_size = wq.group_size if (wq.group_size is not None) else -1
-        if wq.num_bits != 4:
+        self.group_size = (
+            weight_quant.group_size if (weight_quant.group_size is not None) else -1
+        )
+        if weight_quant.num_bits != 4:
             raise ValueError("This method only supports 4-bit weights (num_bits=4).")

         # CPU only
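
The constraints this constructor enforces are unchanged by the refactor, only
re-expressed against the passed-in QuantizationArgs: dynamic per-token activation
quantization plus 4-bit weights. Restated as a standalone predicate (a sketch; the
helper name is not part of vLLM):

    from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

    def w4a8_int_moe_supported(
        weight_quant: QuantizationArgs, input_quant: QuantizationArgs
    ) -> bool:
        # Same checks as the __init__ above: dynamic per-token A8 and 4-bit weights.
        return (
            input_quant.strategy == QuantizationStrategy.TOKEN
            and bool(input_quant.dynamic)
            and weight_quant.num_bits == 4
        )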
