
Commit 517f8db

zyongye authored and simon-mo committed
[gpt-oss] flashinfer mxfp4 (vllm-project#22339)
Signed-off-by: simon-mo <[email protected]>
Signed-off-by: Yongye Zhu <[email protected]>
Co-authored-by: simon-mo <[email protected]>
1 parent d03d33d commit 517f8db

File tree

5 files changed: +453 -3 lines changed


vllm/envs.py

Lines changed: 12 additions & 0 deletions
@@ -154,6 +154,8 @@
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
     VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
     VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False
+    VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
+    VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False


 def get_default_cache_root():
@@ -932,6 +934,16 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_FLASHINFER_MOE_FP4":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))),

+    # If set to 1, use the FlashInfer
+    # MXFP8 (activation) x MXFP4 (weight) MoE backend.
+    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
+
+    # If set to 1, use the FlashInfer
+    # BF16 (activation) x MXFP4 (weight) MoE backend.
+    "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))),
+
     # Control the cache sized used by the xgrammar compiler. The default
     # of 512 MB should be enough for roughly 1000 JSON schemas.
     # It can be changed with this variable if needed for some reason.
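For reference, a minimal sketch of how the two new flags are meant to be consumed, following the same bool(int(os.getenv(...))) pattern the diff adds; the os.environ assignment below is illustration only, not part of the commit:

import os

# The diff registers one lambda per flag; each parses "0"/"1" into a bool.
flags = {
    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
    "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))),
}

# Example: enable the MXFP8 (activation) x MXFP4 (weight) FlashInfer MoE backend.
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"
assert flags["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"]() is True
assert flags["VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"]() is False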

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 29 additions & 3 deletions
@@ -33,7 +33,8 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
+from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
+                        round_up)
 from vllm.utils.flashinfer import has_flashinfer

 if current_platform.is_cuda_alike():
@@ -719,6 +720,12 @@ def __init__(

         self.global_num_experts = num_experts + num_redundant_experts

+        # we padding globally so EP buffer allocation works
+        if quant_config and quant_config.get_name() == "mxfp4" and (
+                envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+            hidden_size = round_up(hidden_size, 256)
+
         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
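The new block above pads hidden_size up to a multiple of 256 whenever mxfp4 quantization is combined with either FlashInfer MXFP4 backend, so expert-parallel buffer allocation sees an aligned size. A minimal sketch of the arithmetic, assuming round_up from vllm.utils is the usual ceiling-to-multiple helper:

def round_up(x: int, multiple: int) -> int:
    # Assumed behavior of vllm.utils.round_up: smallest multiple of `multiple` >= x.
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(2880, 256) == 3072  # e.g. a 2880-wide hidden state is padded to 3072
assert round_up(4096, 256) == 4096  # already-aligned sizes are left unchanged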
@@ -1064,6 +1071,18 @@ def weight_loader(self,
                       shard_id: str,
                       expert_id: int,
                       return_success: bool = False) -> Optional[bool]:
+
+        if self.quant_config and self.quant_config.get_name() == "mxfp4":
+            # (FIXME) for gpt-oss all experts are combined
+            if "bias" in weight_name:
+                dim1 = loaded_weight.shape[1]
+                param.data[:, :dim1].copy_(loaded_weight)
+            else:
+                dim1 = loaded_weight.shape[1]
+                dim2 = loaded_weight.shape[2]
+                param.data[:, :dim1, :dim2].copy_(loaded_weight)
+            return True if return_success else None
+
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             # Failed to load this param since it's not local to this rank
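Because hidden_size may have been padded in __init__, the mxfp4 branch above copies the checkpoint tensor into the leading slice of the (possibly larger) parameter and returns early. A standalone sketch of that partial-copy pattern; the shapes below are made up for illustration:

import torch

# Hypothetical shapes: weights for all experts combined, padded 90 -> 96 on dim 1.
param = torch.nn.Parameter(torch.zeros(32, 96, 64), requires_grad=False)
loaded_weight = torch.randn(32, 90, 64)  # checkpoint tensor at the original size

dim1, dim2 = loaded_weight.shape[1], loaded_weight.shape[2]
param.data[:, :dim1, :dim2].copy_(loaded_weight)  # fill only the leading slice

assert torch.equal(param.data[:, :dim1, :dim2], loaded_weight)
assert param.data[:, dim1:, :].abs().sum() == 0   # the padded tail stays zero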
@@ -1476,13 +1495,20 @@ def maybe_all_reduce_tensor_model_parallel(

     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
+        og_hidden_states = hidden_states.shape[-1]
+        if self.hidden_size != og_hidden_states:
+            hidden_states = F.pad(hidden_states,
+                                  (0, self.hidden_size - og_hidden_states),
+                                  mode='constant',
+                                  value=0.0)
         # TODO: Once the OOM issue for the TPU backend is resolved, we will
         # switch to using the moe_forward custom op.
         if current_platform.is_tpu():
             return self.forward_impl(hidden_states, router_logits)
         else:
-            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                              self.layer_name)
+            return torch.ops.vllm.moe_forward(
+                hidden_states, router_logits,
+                self.layer_name)[..., :og_hidden_states]

     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
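The forward() change mirrors the constructor padding: inputs narrower than the padded self.hidden_size are zero-padded on the last dimension before dispatch, and the result is sliced back so callers never see the extra columns. A self-contained sketch of that pad/slice round trip, with an identity op standing in for torch.ops.vllm.moe_forward:

import torch
import torch.nn.functional as F

padded_hidden_size = 3072             # e.g. round_up(2880, 256)
hidden_states = torch.randn(4, 2880)  # (num_tokens, original hidden size)

og_hidden_states = hidden_states.shape[-1]
if padded_hidden_size != og_hidden_states:
    hidden_states = F.pad(hidden_states,
                          (0, padded_hidden_size - og_hidden_states),
                          mode='constant', value=0.0)

out = hidden_states                   # stand-in for the fused MoE custom op
out = out[..., :og_hidden_states]     # strip the padding from the output
assert out.shape == (4, 2880)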

vllm/model_executor/layers/quantization/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,7 @@
     "auto-round",
     "rtn",
     "inc",
+    "mxfp4",
 ]
 QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

@@ -110,6 +111,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
     from .marlin import MarlinConfig
     from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
     from .moe_wna16 import MoeWNA16Config
+    from .mxfp4 import Mxfp4Config
     from .neuron_quant import NeuronQuantConfig
     from .ptpc_fp8 import PTPCFp8Config
     from .qqq import QQQConfig
@@ -148,6 +150,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
         "auto-round": AutoRoundConfig,
         "rtn": RTNConfig,
         "inc": INCConfig,
+        "mxfp4": Mxfp4Config,
     }
     # Update the `method_to_config` with customized quantization methods.
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
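With these three additions, "mxfp4" becomes a recognized quantization method name: it appears in QuantizationMethods and get_quantization_config() resolves it to the new Mxfp4Config through method_to_config. A hedged usage sketch (exact call sites vary across vLLM):

from vllm.model_executor.layers.quantization import get_quantization_config

# After this commit, "mxfp4" resolves to the Mxfp4Config class added in mxfp4.py.
quant_cls = get_quantization_config("mxfp4")
print(quant_cls.__name__)  # expected: "Mxfp4Config"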
