
Commit 61bcfdb

Add MXFP8 MOE/Linear and MXFP4 Linear (#1034)
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7345fe5 commit 61bcfdb

17 files changed: +915 / -126 lines
Lines changed: 12 additions & 0 deletions (new file; filename not shown in this view)

- Build and Install vLLM

```
git clone --branch fused-moe-ar https://github.com/yiliu30/vllm-fork.git
VLLM_USE_PRECOMPILED=1 pip install --editable . -vvv
```

- Enable vLLM-Ext at Runtime

```bash
VLLM_ENABLE_AR_EXT=1 vllm serve ...
```

auto_round_extension/vllm_ext/__init__.py

Lines changed: 5 additions & 5 deletions

@@ -18,9 +18,9 @@
 
 
 def apply():
-    import vllm.model_executor.layers.quantization.auto_round as auto_round_module
+    import auto_round_extension.vllm_ext.auto_round_ext
+    import auto_round_extension.vllm_ext.envs_ext
 
-    from auto_round_extension.vllm_ext.auto_round_ext import AutoRoundExtensionConfig
-
-    auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig
-    from auto_round_extension.vllm_ext.envs_ext import extra_environment_variables
+    print("*****************************************************************************")
+    print("* !!! VLLM_ENABLE_AR_EXT is set to 1, applying auto_round_vllm_extension *")
+    print("*****************************************************************************")

auto_round_extension/vllm_ext/apply_ext.sh

Lines changed: 0 additions & 46 deletions
This file was deleted.

auto_round_extension/vllm_ext/auto_round_ext.py

Lines changed: 13 additions & 6 deletions

@@ -18,25 +18,26 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
-from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig
+from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig as _BaseAutoRoundConfig
 
 from auto_round.schemes import QuantizationScheme
+from auto_round_extension.vllm_ext.quant_method_linear import AutoRoundQuantLinearMethod
 from auto_round_extension.vllm_ext.quant_method_moe import AutoRoundMoEMethod
 
 logger = init_logger(__name__)
 
 
-class AutoRoundExtensionConfig(AutoRoundConfig):
-    SUPPORTED_DTYPES = AutoRoundConfig.SUPPORTED_DTYPES.union({"mx_fp"})
-    SUPPORTED_FORMATS = AutoRoundConfig.SUPPORTED_FORMATS.union({"auto_round:llm_compressor"})
+class AutoRoundExtensionConfig(_BaseAutoRoundConfig):
+    SUPPORTED_DTYPES = _BaseAutoRoundConfig.SUPPORTED_DTYPES.union({"mx_fp"})
+    SUPPORTED_FORMATS = _BaseAutoRoundConfig.SUPPORTED_FORMATS.union({"auto_round:llm_compressor"})
 
     def get_quant_method(self, layer: torch.nn.Module, prefix: str):
         # FIXME: (yi) make it compatible with `AutoRoundConfig`
         if isinstance(layer, FusedMoE):
             quant_method = AutoRoundMoEMethod.get_moe_method(self, layer, prefix)
             return quant_method
         elif isinstance(layer, LinearBase):
-            return UnquantizedLinearMethod()
+            return AutoRoundQuantLinearMethod.get_method(self, layer, prefix)
         else:
             return None
 
@@ -48,7 +49,7 @@ def _parse_quant_scheme(config: dict):
         return quant_scheme
 
     @classmethod
-    def from_config(cls, config: dict[str, Any]) -> AutoRoundConfig:
+    def from_config(cls, config: dict[str, Any]) -> _BaseAutoRoundConfig:
         ar_config = super().from_config(config)
         # TODO: (yi) refine below implementation
         quant_scheme = AutoRoundExtensionConfig._parse_quant_scheme(config)
@@ -61,3 +62,9 @@ def from_config(cls, config: dict[str, Any]) -> AutoRoundConfig:
         ar_config.quant_scheme = quant_scheme
         ar_config.layer_schemes = layer_schemes
         return ar_config
+
+
+# Patch vLLM's AutoRoundConfig at import time
+import vllm.model_executor.layers.quantization.auto_round as _auto_round_module
+
+_auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig
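
Note: the module-level assignment above replaces the `AutoRoundConfig` attribute on vLLM's `auto_round` module, so only lookups made after this module is imported see the extension class; names bound earlier keep pointing at the stock config. A small sketch of that binding behaviour, assuming the extension package is importable and nothing has cached the old class:

```python
# Hedged sketch of the monkey-patch semantics; not part of the commit.
import vllm.model_executor.layers.quantization.auto_round as ar_mod

before_patch = ar_mod.AutoRoundConfig  # stock vLLM class

# Importing the extension module runs the assignment shown above.
import auto_round_extension.vllm_ext.auto_round_ext  # noqa: F401

after_patch = ar_mod.AutoRoundConfig  # now AutoRoundExtensionConfig
assert after_patch is not before_patch
```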

auto_round_extension/vllm_ext/envs_ext.py

Lines changed: 5 additions & 3 deletions

@@ -21,9 +21,11 @@
 
 # Define extra environment variables
 extra_environment_variables: dict[str, Callable[[], Any]] = {
-    "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "1") in ("1", "true", "True"),
-    "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "1") in ("1", "true", "True"),
-    "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "0") in ("1", "true", "True"),
+    "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "0") in ("1", "true", "True"),
+    "VLLM_MXFP4_PRE_UNPACK_TO_FP8": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_TO_FP8", "1") in ("1", "true", "True"),
+    "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "0") in ("1", "true", "True"),
+    "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "1") in ("1", "true", "True"),
+    "VLLM_AR_POST_PROCESS_GPTOSS": lambda: os.getenv("VLLM_AR_POST_PROCESS_GPTOSS", "0") in ("1", "true", "True"),
 }
 # Add the extra environment variables to vllm.envs
 import vllm.envs as envs
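
Note: the hunk is truncated right after `import vllm.envs as envs`, so the registration code itself is not visible here. A plausible sketch of how the extra variables could be merged into vLLM's environment-variable registry follows; the use of `envs.environment_variables` is an assumption about vLLM's internals, not something confirmed by this commit:

```python
# Hedged sketch: merge the extension's env vars into vllm.envs so that
# envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS and friends resolve like built-in flags.
import vllm.envs as envs

for _name, _getter in extra_environment_variables.items():
    if _name not in envs.environment_variables:
        envs.environment_variables[_name] = _getter
```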

auto_round_extension/vllm_ext/fp4_utils.py

Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     indices = indices.reshape(-1)
 
     # Handle odd length by padding if necessary
-    assert indices.numel() % 2 != 0, f"Expected even number of elements, got {indices.numel()}"
+    # assert indices.numel() % 2 != 0, f"Expected even number of elements, got {indices.numel()}"
 
     # Reshape to pair consecutive elements
     indices = indices.reshape(-1, 2)
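
Note: `pack_fp4_to_uint8` pairs consecutive 4-bit codes into single bytes, which is why the element count must be even. The disabled assert had its condition inverted (it fired on even counts while the message demanded an even count), and the commit comments it out rather than correcting it. A minimal sketch of the nibble-packing idea; the low/high nibble order is an assumption, not necessarily the layout fp4_utils.py uses:

```python
import torch


def pack_nibbles_sketch(indices: torch.Tensor) -> torch.Tensor:
    """Pack pairs of 4-bit codes (values 0..15) into uint8, low nibble first."""
    flat = indices.reshape(-1).to(torch.uint8)
    assert flat.numel() % 2 == 0, "need an even number of 4-bit codes"
    pairs = flat.reshape(-1, 2)
    # Byte = second code in the high nibble, first code in the low nibble.
    return (pairs[:, 1] << 4) | pairs[:, 0]
```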
Lines changed: 129 additions & 0 deletions (new file; filename not shown in this view)

# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional

import torch
import vllm.envs as envs
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.parameter import GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter
from vllm.platforms import current_platform

from auto_round_extension.vllm_ext.mxfp4_qdq_utils import (
    dequant_mxfp4_to_fp8,
    mxfp4_gemm_with_unpacked_weight,
    run_mxfp4_emulations,
)

logger = init_logger(__name__)

__all__ = ["AutoRoundMXFP4LinearImpl"]

from auto_round_extension.vllm_ext.quant_impl import AutoRoundQuantImpl


class AutoRoundMXFP4LinearImpl(AutoRoundQuantImpl):
    def __init__(self, quant_scheme):
        self.quant_scheme = quant_scheme
        self.group_size = 32

    @classmethod
    def get_min_capability(cls) -> int:
        if envs.VLLM_USE_MXFP4_CT_EMULATIONS:
            return 80
        return 100

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(sum(output_partition_sizes), input_size_per_partition // 2, dtype=torch.uint8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_packed", weight)

        # Per Group Weight Scale
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // self.group_size,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer) -> None:
        # FIXME: may dequant to bf16
        if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS:
            weight_fp8, scale_bf16 = dequant_mxfp4_to_fp8(
                data_lp=layer.weight_packed,
                scale_e8m0=layer.weight_scale,
            )
            del layer.weight_packed
            del layer.weight_scale
            layer.weight_packed = None
            layer.weight_scale = None
            layer.register_parameter(
                "weight_unpacked_fp8",
                torch.nn.Parameter(
                    weight_fp8,
                    requires_grad=False,
                ),
            )
            layer.register_parameter(
                "weight_scale_bf16",
                torch.nn.Parameter(
                    scale_bf16,
                    requires_grad=False,
                ),
            )

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if not envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS:
            out = run_mxfp4_emulations(x=x, weight=layer.weight_packed, weight_scale=layer.weight_scale)
            if bias is not None:
                out = out + bias
            return out
        else:
            out = mxfp4_gemm_with_unpacked_weight(
                x=x,
                weight_fp8=layer.weight_unpacked_fp8,
                weight_scale_bf16=layer.weight_scale_bf16,
                bias=bias,
            )
            return out
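
Note: this implementation keeps MXFP4 weights as packed 4-bit E2M1 codes (`weight_packed`, two codes per byte) plus one E8M0 exponent per group of 32 values (`weight_scale`), and either emulates the GEMM directly or pre-unpacks to FP8 depending on `VLLM_MXFP4_PRE_UNPACK_WEIGHTS`. For readers unfamiliar with the format, the sketch below shows the basic group dequantization math; the value table and the 127 bias follow the OCP Microscaling (MX) spec, and the helper name is illustrative rather than the one exported by `mxfp4_qdq_utils`:

```python
import torch

# The 16 representable E2M1 (FP4) values per the OCP Microscaling spec.
_FP4_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)


def dequant_mxfp4_sketch(codes: torch.Tensor, scale_e8m0: torch.Tensor,
                         group_size: int = 32) -> torch.Tensor:
    """Hedged sketch of MXFP4 group dequantization (not the project's kernel).

    codes:      unpacked 4-bit indices, shape [rows, cols]
    scale_e8m0: biased power-of-two exponents (uint8), shape [rows, cols // group_size]
    """
    values = _FP4_VALUES[codes.long()]
    # E8M0 stores a biased exponent; 127 is the bias, so the scale is 2**(e - 127).
    scales = torch.pow(2.0, scale_e8m0.to(torch.float32) - 127.0)
    return values * scales.repeat_interleave(group_size, dim=-1)
```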
Lines changed: 112 additions & 0 deletions (new file; filename not shown in this view)

# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import torch
import vllm.envs as envs
from vllm.model_executor.parameter import (
    GroupQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)

from auto_round_extension.vllm_ext.mxfp8_qdq_utils import dequant_mx_fp8, quant_mx_fp8
from auto_round_extension.vllm_ext.quant_impl import AutoRoundQuantImpl


class AutoRoundMXFP8LinearImpl(AutoRoundQuantImpl):
    def __init__(self, quant_scheme):
        self.quant_scheme = quant_scheme
        self.strategy = "TENSOR_GROUP"
        self.out_dtype = torch.get_default_dtype()
        self.group_size = 32

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    def process_weights_after_loading(self, layer) -> None:
        return

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        # Per Group Weight Scale
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // self.group_size,
                dtype=torch.uint8,  # E8M0 for MXFP8 scale
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Dequantize the weight
        weight = layer.weight
        weight_scale = layer.weight_scale
        dequant_weight = dequant_mx_fp8(
            weight_fp8=weight.data,
            scale_e8m0=weight_scale.data,
            block_size=self.group_size,
            target_dtype=x.dtype,
        )
        dequant_weight = dequant_weight.to(x.dtype)
        # Quantize-dequantize (q-dq) the input to emulate MXFP8 activations
        x_scale, x_quant = quant_mx_fp8(x)
        dequant_x = dequant_mx_fp8(
            weight_fp8=x_quant,
            scale_e8m0=x_scale,
            block_size=self.group_size,
            target_dtype=x.dtype,
        )
        x = dequant_x.to(x.dtype)

        out = x @ dequant_weight.t()
        return out.to(x.dtype) + (bias if bias is not None else 0)
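
Note: `apply_weights` here is a pure emulation path: the FP8 weight is dequantized, the activation is quantized and immediately dequantized, and the matmul runs in the activation dtype, so results reflect MXFP8 rounding without a native FP8 GEMM. The sketch below illustrates the per-group quantization step that `quant_mx_fp8` is expected to perform; the exponent rule follows the OCP MX convention, and the helper name and exact rounding are assumptions, not the project's implementation:

```python
import torch


def quant_mx_fp8_sketch(x: torch.Tensor, group_size: int = 32):
    """Hedged sketch of MX FP8 (E4M3) group quantization; not the project's code.

    Returns (scale_e8m0, x_fp8): each group of `group_size` values along the last
    dim shares one power-of-two scale stored as a biased E8M0 exponent.
    """
    groups = x.reshape(*x.shape[:-1], -1, group_size).to(torch.float32)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=torch.finfo(torch.float32).tiny)
    # OCP MX rule: shared exponent = floor(log2(amax)) - emax_elem, with emax_elem = 8
    # for E4M3 (largest representable magnitude 448).
    exp = torch.floor(torch.log2(amax)) - 8.0
    scale = torch.pow(2.0, exp)
    x_fp8 = (groups / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    scale_e8m0 = (exp + 127.0).clamp(0, 255).squeeze(-1).to(torch.uint8)
    return scale_e8m0, x_fp8.reshape(x.shape)
```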
