@@ -7,7 +7,11 @@
 
 import torch
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
+from compressed_tensors.quantization import (
+    ActivationOrdering,
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
@@ -142,10 +146,26 @@ def get_moe_method(
         # are supported + check if the layer is being ignored.
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
+        format = scheme_dict.get("format")
 
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             # group_size=None means channelwise
             group_size = weight_quant.group_size or -1
+
+            valid_format_and_bits = (
+                weight_quant.num_bits in WNA16_SUPPORTED_BITS
+                and format == CompressionFormat.pack_quantized.value
+            )
+
+            if not valid_format_and_bits:
+                raise ValueError(
+                    "For Fused MoE layers, only format: "
+                    f"{CompressionFormat.pack_quantized.value} "
+                    f"and bits: {WNA16_SUPPORTED_BITS} are supported, "
+                    f"but got format: {format} "
+                    f"and bits: {weight_quant.num_bits}"
+                )
+
             # Prefer to use the MarlinMoE kernel when it is supported.
             if (
                 not check_moe_marlin_supports_layer(layer, group_size)
@@ -161,12 +181,12 @@ def get_moe_method(
                 )
                 logger.info_once("Using CompressedTensorsWNA16MoEMethod")
                 return CompressedTensorsWNA16MoEMethod(
-                    quant_config, layer.moe_config, layer_name
+                    weight_quant, input_quant, layer.moe_config
                 )
             else:
                 logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                 return CompressedTensorsWNA16MarlinMoEMethod(
-                    quant_config, layer.moe_config, layer_name
+                    weight_quant, input_quant, layer.moe_config
                 )
         elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
@@ -176,15 +196,15 @@ def get_moe_method(
             or quant_config._is_fp8_w8a8(weight_quant, input_quant)
         ):
             return CompressedTensorsW8A8Fp8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
            )
         elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
             return CompressedTensorsW4A8Int8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         else:
             raise RuntimeError(
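
Note: the dispatch above now resolves the quantization scheme once and hands the `QuantizationArgs` straight to each MoE method, so the constructors no longer reach back into `quant_config.target_scheme_map`. A minimal sketch of the resulting call pattern (names follow the diff; `scheme_dict` is the matched scheme entry from the config):

```python
# Resolved once in get_moe_method, then passed down:
weight_quant = scheme_dict.get("weights")           # QuantizationArgs
input_quant = scheme_dict.get("input_activations")  # QuantizationArgs or None

method = CompressedTensorsW8A8Fp8MoEMethod(
    weight_quant, input_quant, layer.moe_config
)
```

Also worth noting: `layer_name` is dropped from these call sites, while the constructors keep it as an optional trailing parameter, so any existing keyword callers remain valid.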
@@ -651,17 +671,19 @@ def apply(
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
-        super().__init__(moe)
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations"
+        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+            CompressedTensorsConfig,
         )
 
+        super().__init__(moe)
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+
         per_tensor = (
             self.weight_quant.strategy == QuantizationStrategy.TENSOR
             and self.input_quant.strategy == QuantizationStrategy.TENSOR
@@ -699,11 +721,13 @@ def __init__(
         self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
 
         # cutlass path
-        self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
+        self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
             self.weight_quant, self.input_quant
         )
         self.use_cutlass = not self.block_quant and (
-            quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
+            CompressedTensorsConfig._is_fp8_w8a8_sm90(
+                self.weight_quant, self.input_quant
+            )
             or self.is_fp8_w8a8_sm100
         )
         self.disable_expert_map = False
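
Because the constructor no longer holds a `CompressedTensorsConfig` instance, the SM90/SM100 capability checks are reached through the class itself, and the import happens inside `__init__` to sidestep the circular dependency with `compressed_tensors.py`. A generic sketch of that deferred-import pattern (module and names are illustrative, not from vLLM):

```python
# config_mod.py imports methods_mod.py at module scope, so methods_mod.py
# must not import config_mod.py at its own module scope; defer it instead.
class MoEMethod:
    def __init__(self, weight_quant, input_quant):
        from config_mod import Config  # resolved at call time, cycle avoided

        self.supported = Config.is_supported(weight_quant, input_quant)
```

This presumes `_is_fp8_w8a8_sm90`/`_is_fp8_w8a8_sm100` only inspect the scheme arguments, so calling them through the class rather than a config instance is safe.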
@@ -1263,16 +1287,14 @@ def supports_eplb(self) -> bool:
 class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations"
-        )
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
 
         per_channel = (
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
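
One practical upshot of the narrower constructor: this method can now be built directly from scheme arguments, without a full model config. A sketch under the assumption that `QuantizationArgs` accepts these fields as keywords (it is the `compressed_tensors` pydantic model imported at the top of the file) and that `moe_config` is a `FusedMoEConfig` you already have:

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

# Channelwise int8 weights + dynamic per-token int8 activations: the shape
# of scheme this method's per_channel check expects.
weight_quant = QuantizationArgs(
    num_bits=8, strategy=QuantizationStrategy.CHANNEL, symmetric=True
)
input_quant = QuantizationArgs(
    num_bits=8, strategy=QuantizationStrategy.TOKEN, dynamic=True
)
method = CompressedTensorsW8A8Int8MoEMethod(weight_quant, input_quant, moe_config)
```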
@@ -1417,36 +1439,27 @@ def apply(
 class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs | None,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        # TODO: @dsikka: refactor this to use schemes as other kernels
-        # are supported + check if the layer is being ignored.
-        config = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.num_bits = config.num_bits
-        self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy
-        self.group_size = config.group_size
-        self.actorder = config.actorder
-        self.layer_name = layer_name
-        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
-        assert config.symmetric, "Only symmetric quantization is supported for MoE"
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+        assert weight_quant.symmetric, (
+            "Only symmetric quantization is supported for MoE"
+        )
+        # Extract properties from weight_quant
+        self.num_bits = weight_quant.num_bits
+        self.packed_factor = 32 // weight_quant.num_bits
+        self.strategy = weight_quant.strategy
+        self.group_size = weight_quant.group_size
+        self.actorder = weight_quant.actorder
 
-        if not (
-            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
-            and self.num_bits in WNA16_SUPPORTED_BITS
-        ):
-            raise ValueError(
-                "For Fused MoE layers, only ",
-                f"{CompressionFormat.pack_quantized.value} ",
-                "is supported for the following bits: ",
-                f"{WNA16_SUPPORTED_BITS}",
-            )
         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
         self.use_marlin = True
+        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
 
     def create_weights(
         self,
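
For reference, `packed_factor = 32 // num_bits` is simply how many quantized weights fit into one packed 32-bit word:

```python
# For the 4- and 8-bit widths in WNA16_SUPPORTED_BITS:
assert 32 // 4 == 8  # W4: eight 4-bit values per int32
assert 32 // 8 == 4  # W8: four 8-bit values per int32
```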
@@ -1816,35 +1829,26 @@ def apply(
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs | None,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        # TODO: @dsikka: refactor this to use schemes as other kernels
-        # are supported + check if the layer is being ignored.
-        config = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.num_bits = config.num_bits
-        self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+        # Extract properties from weight_quant
+        self.num_bits = weight_quant.num_bits
+        self.packed_factor = 32 // weight_quant.num_bits
+        self.strategy = weight_quant.strategy
         # channelwise is not supported by this kernel
-        assert config.strategy == "group"
-        self.group_size = config.group_size
+        assert weight_quant.strategy == "group"
+        self.group_size = weight_quant.group_size
         # grouped actorder isn't supported by this kernel
-        assert config.actorder != "group"
-        assert config.symmetric, "Only symmetric quantization is supported for MoE"
-
-        if not (
-            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
-            and self.num_bits in WNA16_SUPPORTED_BITS
-        ):
-            raise ValueError(
-                "For Fused MoE layers, only ",
-                f"{CompressionFormat.pack_quantized.value} ",
-                "is supported for the following bits: ",
-                f"{WNA16_SUPPORTED_BITS}",
-            )
+        assert weight_quant.actorder != "group"
+        assert weight_quant.symmetric, (
+            "Only symmetric quantization is supported for MoE"
+        )
 
     def create_weights(
         self,
@@ -2070,28 +2074,33 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
         self.has_bias = self.moe.has_bias
-        self.quant_config = quant_config
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
 
         # Validate scheme: weights=W4 (channel or group),
         # activations=dynamic TOKEN (A8)
-        wq = self.quant_config.target_scheme_map["Linear"].get("weights")
-        aq = self.quant_config.target_scheme_map["Linear"].get("input_activations")
 
         # Must be dynamic per-token activations
-        if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
+        if (
+            input_quant.strategy != QuantizationStrategy.TOKEN
+            or not input_quant.dynamic
+        ):
             raise ValueError(
                 "W4A8-int MoE needs dynamic per-token activation quantization."
             )
 
         # Weight can be channel-wise (group_size=None) or group-wise
-        self.group_size = wq.group_size if (wq.group_size is not None) else -1
-        if wq.num_bits != 4:
+        self.group_size = (
+            weight_quant.group_size if (weight_quant.group_size is not None) else -1
+        )
+        if weight_quant.num_bits != 4:
             raise ValueError("This method only supports 4-bit weights (num_bits=4).")
 
         # CPU only
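
The `group_size` normalization above keeps the file's channelwise convention: `None` in the scheme means one scale per output channel, which the kernels encode as `-1`. An illustrative stand-alone helper (hypothetical, not part of the diff):

```python
def normalize_group_size(group_size: int | None) -> int:
    # None -> channelwise, encoded as -1; explicit group sizes pass through
    return group_size if group_size is not None else -1

assert normalize_group_size(None) == -1
assert normalize_group_size(128) == 128
```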