10 changes: 10 additions & 0 deletions vllm/envs.py
@@ -102,6 +102,8 @@
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_CK_MXFP4_MOE: bool = False
VLLM_ROCM_USE_EMULATED_MXFP4_MOE: bool = False
VLLM_TRITON_FP4_GEMM_USE_ASM: bool = False
VLLM_USE_AITER_TRITON_ROPE: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
@@ -801,6 +803,14 @@
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")),

"VLLM_ROCM_USE_CK_MXFP4_MOE":
lambda: (os.getenv("VLLM_ROCM_USE_CK_MXFP4_MOE", "False").lower() in

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably toggle this to True by default

("true", "1")),

"VLLM_ROCM_USE_EMULATED_MXFP4_MOE":
lambda: (os.getenv("VLLM_ROCM_USE_EMULATED_MXFP4_MOE", "False").lower() in
("true", "1")),

# Whether to use aiter fp4 gemm asm.
# Disabled by default.
"VLLM_TRITON_FP4_GEMM_USE_ASM":
@@ -1237,15 +1247,15 @@
# Use AITER Triton fused RMSNORM + Quantization
"VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT", "1"))),

Check failure (GitHub Actions / pre-commit) — Ruff E501 at vllm/envs.py:1250:81: Line too long (92 > 80)
# Use AITER Triton fused elementwise multiply + elementwise addition
"VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD", "1"))),

Check failure (GitHub Actions / pre-commit) — Ruff E501 at vllm/envs.py:1254:81: Line too long (82 > 80)
# Use AITER Triton fused rope + zeros + reshape_and_cache
"VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "1"))),

Check failure (GitHub Actions / pre-commit) — Ruff E501 at vllm/envs.py:1258:81: Line too long (94 > 80)
# Use AITER Triton fused FP8 per-token group quant + FP8 batched GEMM
"VLLM_ROCM_USE_AITER_TRITON_FP8_BMM":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FP8_BMM", "1"))),
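For context, both new flags reuse the truthiness rule already used throughout `vllm/envs.py`. A minimal standalone sketch of that parsing (the `_env_flag` helper is illustrative only, not part of `vllm.envs`):

```python
# Standalone sketch (not part of this diff): the two new ROCm MXFP4 flags follow
# the same "true"/"1" rule as the existing env lambdas, which vllm.envs evaluates
# lazily on attribute access.
import os

def _env_flag(name: str, default: str = "False") -> bool:
    # "true" or "1" (case-insensitive) -> True; anything else -> False.
    return os.getenv(name, default).lower() in ("true", "1")

# Example: opt into the CK MXFP4 MoE kernel for this process.
os.environ["VLLM_ROCM_USE_CK_MXFP4_MOE"] = "1"
print(_env_flag("VLLM_ROCM_USE_CK_MXFP4_MOE"))        # True
print(_env_flag("VLLM_ROCM_USE_EMULATED_MXFP4_MOE"))  # False (default)
```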
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -36,7 +36,7 @@
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4)
dequant_mxfp4, use_fp4_aiter_moe)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
@@ -530,7 +530,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
elif use_mxfp4_w4a4:
elif use_mxfp4_w4a4 and use_fp4_aiter_moe():
assert A_scale is not None
assert B_scale is not None
else:
@@ -1659,7 +1659,7 @@ def fused_experts_impl(
else:
out_hidden_states = torch.empty_like(hidden_states)

if use_mxfp4_w4a4 and not current_platform.supports_mx():
if use_mxfp4_w4a4 and not use_fp4_aiter_moe():
# Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None
@@ -1719,7 +1719,7 @@
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4
and current_platform.supports_mx(),
and use_fp4_aiter_moe(),
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w1_bias)
@@ -1771,7 +1771,7 @@
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4
and current_platform.supports_mx(),
and use_fp4_aiter_moe(),
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w2_bias)
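The change above routes `use_mxfp4_w4a4` through `use_fp4_aiter_moe()` instead of `current_platform.supports_mx()`. A minimal sketch of the resulting weight-side gating, with illustrative helper names (the actual dequantization in the diff is `dequant_mxfp4(w, w_scale, hidden_states.dtype)`):

```python
# Sketch (not vLLM API): with native AITER MXFP4 support the packed fp4 weight
# and its e8m0 block scale go straight to the fused kernel; otherwise the weight
# is dequantized once to the activation dtype and the scale is dropped.
import torch
from typing import Callable, Optional, Tuple

def prepare_mxfp4_weight(
    w: torch.Tensor,
    w_scale: torch.Tensor,
    act_dtype: torch.dtype,
    native_fp4: bool,
    dequant_fn: Callable[[torch.Tensor, torch.Tensor, torch.dtype], torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if native_fp4:
        # use_fp4_aiter_moe() is True: keep packed mxfp4 data + block scales.
        return w, w_scale
    # Emulation: dequantize up front; the kernel then sees a plain dense weight.
    return dequant_fn(w, w_scale, act_dtype), None
```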
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/fused_moe/utils.py
@@ -11,15 +11,15 @@
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
quant_dequant_mxfp4)
quant_dequant_mxfp4, use_fp4_aiter_moe)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_quantize)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.utils.flashinfer import fp4_quantize

if current_platform.supports_mx():
if use_fp4_aiter_moe():
from aiter.ops.triton.quant import dynamic_mxfp4_quant


@@ -174,7 +174,7 @@ def _mxfp4_quantize(
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert block_shape is None
if not current_platform.supports_mx():
if not use_fp4_aiter_moe():
A = quant_dequant_mxfp4(A)
return A, A_scale
if A_scale is not None:
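The same gate also decides how activations are handled in `_mxfp4_quantize`: fake quant-dequant when emulating, real dynamic MXFP4 quantization when the AITER path is available. A short sketch under that assumption (names are illustrative, not vLLM API):

```python
# Sketch of the activation-side branch: without native support the activation is
# fake-quantized (quantize -> dequantize) and stays in bf16/fp16; with native
# support a dynamic MXFP4 quantization returns packed fp4 data plus e8m0 scales
# (e.g. aiter's dynamic_mxfp4_quant).
import torch
from typing import Callable, Optional, Tuple

def mxfp4_quantize_activation(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    native_fp4: bool,
    fake_quant: Callable[[torch.Tensor], torch.Tensor],
    dynamic_quant: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if not native_fp4:
        # Emulation: same dtype out, scale passed through unchanged.
        return fake_quant(A), A_scale
    # Native: packed fp4 activation + per-block scale for the AITER kernel.
    return dynamic_quant(A)
```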
107 changes: 72 additions & 35 deletions vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -5,13 +5,14 @@

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE)
OCP_MX_BLOCK_SIZE, use_fp4_aiter_moe)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
@@ -293,7 +294,7 @@ def __init__(
"QuarkW4A4MXFp4MoEMethod with static input scales is currently "
"not implemented. Please open an issue.")

if not current_platform.supports_mx():
if not use_fp4_aiter_moe():
self.emulate = True
logger.warning_once(
"The current platform does not support native MXFP4 "
@@ -360,6 +361,23 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)


def process_weights_after_loading(self, layer):
if envs.VLLM_ROCM_USE_CK_MXFP4_MOE:
from aiter.utility.fp4_utils import e8m0_shuffle

# Pre-shuffle weight scales
s0, s1, _ = layer.w13_weight_scale.shape
w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
w13_weight_scale = e8m0_shuffle(w13_weight_scale)
layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)

s0, s1, _ = layer.w2_weight_scale.shape
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
torch.cuda.empty_cache()

def apply(
self,
layer: torch.nn.Module,
@@ -388,37 +406,56 @@ def apply(
raise NotImplementedError(
"EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.")

from vllm.model_executor.layers.fused_moe import fused_experts

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)

out = fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_mxfp4_w4a4=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None,
block_shape=None,
activation=activation,
)
if envs.VLLM_ROCM_USE_CK_MXFP4_MOE:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe

out = fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=(ActivationType.Silu
if activation == "silu" else ActivationType.Gelu),
doweight_stage1=False,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts


topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)

out = fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_mxfp4_w4a4=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None,
block_shape=None,
activation=activation,
)
return out
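The scale pre-shuffle added in `process_weights_after_loading` above follows a flatten-shuffle-restore pattern. A runnable sketch with an identity stand-in for `e8m0_shuffle` (the real call needs aiter installed):

```python
# Sketch of the reshape logic only; aiter.utility.fp4_utils.e8m0_shuffle is
# replaced by a pass-through so this runs without aiter.
import torch

def preshuffle_scale(scale: torch.Tensor, shuffle_fn=lambda t: t) -> torch.Tensor:
    """(num_experts, rows, groups) -> 2-D view -> shuffle -> original shape."""
    s0, s1, _ = scale.shape
    flat = scale.view(s0 * s1, -1)   # one row per (expert, output row)
    flat = shuffle_fn(flat)          # e8m0_shuffle reorders the scale layout for the CK kernel
    return flat.view(s0, s1, -1)

w13_scale = torch.randint(0, 255, (8, 256, 64), dtype=torch.uint8)
assert preshuffle_scale(w13_scale).shape == (8, 256, 64)
```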
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cache
from typing import Callable, Optional

import torch
@@ -13,6 +14,9 @@

OCP_MX_BLOCK_SIZE = 32

@cache
def use_fp4_aiter_moe():
return current_platform.supports_mx() and envs.VLLM_ROCM_USE_AITER and not envs.VLLM_ROCM_USE_EMULATED_MXFP4_MOE

def _swizzle_mxfp4(quant_tensor, scale, num_warps):
""" weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
@@ -33,7 +37,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
scale_layout, scale_layout_opts = StridedLayout, dict()

elif current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx950
from vllm.platforms.rocm import on_gfx950
from triton_kernels.target_info import is_hip
from triton_kernels.tensor_details.layout import (
BlackwellMXScaleLayout, GFX950MXScaleLayout, HopperMXScaleLayout,
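Taken together, the flags touched in this PR select one of three MoE paths. A simplified summary sketch (illustrative, not vLLM API; precedence mirrors the check order in `quark_moe.apply()` and the definition of `use_fp4_aiter_moe()` above):

```python
# Decision summary: the CK flag is checked first in quark_moe.apply(); otherwise
# use_fp4_aiter_moe() picks native vs emulated mxfp4 inside fused_experts.
def mxfp4_moe_path(use_ck_kernel: bool,      # VLLM_ROCM_USE_CK_MXFP4_MOE
                   supports_mx: bool,        # current_platform.supports_mx()
                   use_aiter: bool,          # VLLM_ROCM_USE_AITER
                   force_emulation: bool     # VLLM_ROCM_USE_EMULATED_MXFP4_MOE
                   ) -> str:
    if use_ck_kernel:
        return "aiter CK fused_moe (pre-shuffled e8m0 scales)"
    if supports_mx and use_aiter and not force_emulation:    # use_fp4_aiter_moe()
        return "vLLM fused_experts, native mxfp4 w4a4"
    return "vLLM fused_experts, emulated (weights dequantized)"

assert mxfp4_moe_path(False, True, True, False).endswith("native mxfp4 w4a4")
```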