From f5d4a13fbf3832614f15b24708ec199f9772f630 Mon Sep 17 00:00:00 2001
From: jpvillam
Date: Tue, 30 Sep 2025 16:28:58 -0500
Subject: [PATCH 1/2] Quick port of fp4 fusedmoe

---
 vllm/envs.py                                       |  10 ++
 .../layers/fused_moe/fused_moe.py                  |  10 +-
 vllm/model_executor/layers/fused_moe/utils.py      |   6 +-
 .../layers/quantization/quark/quark_moe.py         | 106 ++++++++++++------
 .../layers/quantization/utils/mxfp4_utils.py       |   6 +-
 5 files changed, 94 insertions(+), 44 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index 9ff40fb99cab..ce6b94626690 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -102,6 +102,8 @@
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
     VLLM_ROCM_USE_AITER_MHA: bool = True
+    VLLM_ROCM_USE_CK_MXFP4_MOE: bool = False
+    VLLM_ROCM_USE_EMULATED_MXFP4_MOE: bool = False
     VLLM_TRITON_FP4_GEMM_USE_ASM: bool = False
     VLLM_USE_AITER_TRITON_ROPE: bool = True
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
@@ -801,6 +803,14 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
              ("true", "1")),
 
+    "VLLM_ROCM_USE_CK_MXFP4_MOE":
+    lambda: (os.getenv("VLLM_ROCM_USE_CK_MXFP4_MOE", "False").lower() in
+             ("true", "1")),
+
+    "VLLM_ROCM_USE_EMULATED_MXFP4_MOE":
+    lambda: (os.getenv("VLLM_ROCM_USE_EMULATED_MXFP4_MOE", "False").lower() in
+             ("true", "1")),
+
     # Whether to use aiter fp4 gemm asm.
     # By default is disabled.
     "VLLM_TRITON_FP4_GEMM_USE_ASM":
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 3c94f7ef7983..e50103690122 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -36,7 +36,7 @@
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     calculate_tile_tokens_dim)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
-    dequant_mxfp4)
+    dequant_mxfp4, use_fp4_aiter_moe)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
@@ -530,7 +530,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
     elif use_int8_w8a16 or use_int4_w4a16:
         assert B_scale is not None
         assert block_shape is None or block_shape[0] == 0
-    elif use_mxfp4_w4a4:
+    elif use_mxfp4_w4a4 and use_fp4_aiter_moe():
         assert A_scale is not None
         assert B_scale is not None
     else:
@@ -1659,7 +1659,7 @@ def fused_experts_impl(
     else:
         out_hidden_states = torch.empty_like(hidden_states)
 
-    if use_mxfp4_w4a4 and not current_platform.supports_mx():
+    if use_mxfp4_w4a4 and not use_fp4_aiter_moe():
         # Weight has to be dequantized for mxfp4 emulation.
         w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
         w1_scale = None
@@ -1719,7 +1719,7 @@ def fused_experts_impl(
                                 use_int8_w8a16=use_int8_w8a16,
                                 use_int4_w4a16=use_int4_w4a16,
                                 use_mxfp4_w4a4=use_mxfp4_w4a4
-                                and current_platform.supports_mx(),
+                                and use_fp4_aiter_moe(),
                                 per_channel_quant=per_channel_quant,
                                 block_shape=block_shape,
                                 B_bias=w1_bias)
@@ -1771,7 +1771,7 @@ def fused_experts_impl(
                                 use_int8_w8a16=use_int8_w8a16,
                                 use_int4_w4a16=use_int4_w4a16,
                                 use_mxfp4_w4a4=use_mxfp4_w4a4
-                                and current_platform.supports_mx(),
+                                and use_fp4_aiter_moe(),
                                 per_channel_quant=per_channel_quant,
                                 block_shape=block_shape,
                                 B_bias=w2_bias)
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index 56b24d2691bd..c2946dbfab37 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -11,7 +11,7 @@
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_group_quant_int8, per_token_quant_int8)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
-    quant_dequant_mxfp4)
+    quant_dequant_mxfp4, use_fp4_aiter_moe)
 from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
     mxfp8_quantize)
 from vllm.platforms import current_platform
@@ -19,7 +19,7 @@
 from vllm.utils import cdiv
 from vllm.utils.flashinfer import fp4_quantize
 
-if current_platform.supports_mx():
+if use_fp4_aiter_moe():
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
 
 
@@ -174,7 +174,7 @@ def _mxfp4_quantize(
     block_shape: Optional[list[int]] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
     assert block_shape is None
-    if not current_platform.supports_mx():
+    if not use_fp4_aiter_moe():
         A = quant_dequant_mxfp4(A)
         return A, A_scale
     if A_scale is not None:
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index e4771056cc4e..1786d04c4bd8 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -11,7 +11,7 @@
                                                   FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
-    OCP_MX_BLOCK_SIZE)
+    OCP_MX_BLOCK_SIZE, use_fp4_aiter_moe)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -293,7 +293,7 @@ def __init__(
                 "QuarkW4A4MXFp4MoEMethod with static input scales is currently "
                 "not implemented. Please open an issue.")
 
-        if not current_platform.supports_mx():
+        if not use_fp4_aiter_moe():
             self.emulate = True
             logger.warning_once(
                 "The current platform does not support native MXFP4 "
@@ -360,6 +360,23 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
+
+    def process_weights_after_loading(self, layer):
+        if envs.VLLM_ROCM_USE_CK_MXFP4_MOE:
+            from aiter.utility.fp4_utils import e8m0_shuffle
+
+            # Pre-shuffle weight scales
+            s0, s1, _ = layer.w13_weight_scale.shape
+            w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
+            w13_weight_scale = e8m0_shuffle(w13_weight_scale)
+            layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
+
+            s0, s1, _ = layer.w2_weight_scale.shape
+            w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
+            w2_weight_scale = e8m0_shuffle(w2_weight_scale)
+            layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+            torch.cuda.empty_cache()
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -388,37 +405,56 @@ def apply(
             raise NotImplementedError(
                 "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.")
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype)
-
-        out = fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            use_mxfp4_w4a4=True,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=None,
-            a2_scale=None,
-            block_shape=None,
-            activation=activation,
-        )
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
+
+        if envs.VLLM_ROCM_USE_CK_MXFP4_MOE:
+            from aiter import ActivationType, QuantType
+            from aiter.fused_moe import fused_moe
+
+            out = fused_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                quant_type=QuantType.per_1x32,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                activation=(ActivationType.Silu
+                            if activation == "silu" else ActivationType.Gelu),
+                doweight_stage1=False,
+            )
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+
+
+            out = fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                use_mxfp4_w4a4=True,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=None,
+                a2_scale=None,
+                block_shape=None,
+                activation=activation,
+            )
         return out
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index c0a1110de221..e874a2aaace1 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from functools import cache
 from typing import Callable, Optional
 
 import torch
@@ -13,6 +14,9 @@
 
 OCP_MX_BLOCK_SIZE = 32
 
+@cache
+def use_fp4_aiter_moe():
+    return current_platform.supports_mx() and envs.VLLM_ROCM_USE_AITER and not envs.VLLM_ROCM_USE_EMULATED_MXFP4_MOE
 
 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     """ weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
@@ -33,7 +37,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
         scale_layout, scale_layout_opts = StridedLayout, dict()
 
     elif current_platform.is_rocm():
-        from vllm.platforms.rocm import on_gfx950
+        from vllm.platforms.rocm import on_gfx950
         from triton_kernels.target_info import is_hip
         from triton_kernels.tensor_details.layout import (
             BlackwellMXScaleLayout, GFX950MXScaleLayout, HopperMXScaleLayout,

From 9d49387c3897d8f60fbac2b732f219a9ed307e55 Mon Sep 17 00:00:00 2001
From: jpvillam
Date: Wed, 1 Oct 2025 13:46:10 -0500
Subject: [PATCH 2/2] Add missing import

---
 vllm/model_executor/layers/quantization/quark/quark_moe.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 1786d04c4bd8..3b9f0a0db4cf 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -5,6 +5,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,