
Commit caea443

mawong-amd authored and Doug Lehr committed
Integrate mxfp4 MoE native kernels
Signed-off-by: Matthew Wong <[email protected]>
1 parent 48dc133 commit caea443

File tree

4 files changed: +80 -20 lines changed


vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 67 additions & 1 deletion
@@ -44,6 +44,11 @@
 
 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
 
+try:
+    from aiter.ops.triton.moe_op_mxfp4 import _fused_moe_kernel_mxfp4
+except ImportError:
+    _fused_moe_kernel_mxfp4 = None
+
 logger = init_logger(__name__)
 
 
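For context, the guarded import above means _fused_moe_kernel_mxfp4 is simply None on builds without AITER's Triton ops. A minimal standalone sketch of that pattern (the has_native_mxfp4_kernel helper is illustrative only, not part of this commit):

try:
    from aiter.ops.triton.moe_op_mxfp4 import _fused_moe_kernel_mxfp4
except ImportError:
    # Native mxfp4 MoE kernel unavailable; callers must fall back or emulate.
    _fused_moe_kernel_mxfp4 = None


def has_native_mxfp4_kernel() -> bool:
    # Feature-check before dispatching to the native kernel path.
    return _fused_moe_kernel_mxfp4 is not None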
@@ -507,6 +512,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                             use_int8_w8a8: bool,
                             use_int8_w8a16: bool,
                             use_int4_w4a16: bool,
+                            use_mxfp4_w4a4: bool,
                             per_channel_quant: bool,
                             block_shape: Optional[list[int]] = None,
                             B_bias: Optional[torch.Tensor] = None) -> None:
@@ -524,6 +530,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
     elif use_int8_w8a16 or use_int4_w4a16:
         assert B_scale is not None
         assert block_shape is None or block_shape[0] == 0
+    elif use_mxfp4_w4a4:
+        assert A_scale is not None
+        assert B_scale is not None
     else:
         assert A_scale is None
         assert B_scale is None
@@ -611,6 +620,55 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             use_int8_w8a16=use_int8_w8a16,
             **config,
         )
+    elif use_mxfp4_w4a4:
+        ONE = torch.ones(B.size(0), dtype=torch.float32, device=A.device)
+        # overwrite config with a static one for now
+        config = {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 128,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 4,
+            "num_warps": 8,
+            "num_stages": 2,
+            "waves_per_eu": 0,
+            "matrix_instr_nonkdim": 16,
+            "kpack": 1,
+        }
+        _fused_moe_kernel_mxfp4[grid](
+            A,
+            B,
+            C,
+            ONE[0],
+            ONE,
+            A_scale,
+            B_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.size(1),
+            A.size(1),
+            EM,
+            num_tokens,
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            A_scale.stride(0),
+            A_scale.stride(1),
+            B_scale.stride(0),
+            B_scale.stride(2),
+            B_scale.stride(1),
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            SWIZZLE_MX_A=False,
+            SWIZZLE_MX_B=False,
+            **config,
+        )
     else:
         config = config.copy()
         BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
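Note that this branch skips the tuned-config lookup and launches _fused_moe_kernel_mxfp4 with a hard-coded tile configuration and MX-scale swizzling disabled (SWIZZLE_MX_A/B=False). A hypothetical helper restating that choice (name and structure are illustrative, not from the commit):

def static_mxfp4_moe_config() -> dict:
    # Mirrors the fixed launch parameters used by the mxfp4 branch above;
    # per the in-diff comment, a tuned config could replace this later.
    return {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 4,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 1,
    }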
@@ -1570,7 +1628,7 @@ def fused_experts_impl(
     else:
         out_hidden_states = torch.empty_like(hidden_states)
 
-    if use_mxfp4_w4a4:
+    if use_mxfp4_w4a4 and not current_platform.supports_mx():
         # Weight has to be dequantized for mxfp4 emulation.
         w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
         w1_scale = None
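The condition above narrows weight dequantization to the emulation case: when current_platform.supports_mx() is true, weights stay in mxfp4 and the flag is forwarded to the Triton kernel at the call sites below. A standalone restatement of that gating (helper names are illustrative, not part of the commit):

def use_native_mxfp4(use_mxfp4_w4a4: bool, supports_mx: bool) -> bool:
    # Native kernel path: mxfp4 requested and the platform exposes MX support.
    return use_mxfp4_w4a4 and supports_mx


def use_mxfp4_emulation(use_mxfp4_w4a4: bool, supports_mx: bool) -> bool:
    # Emulation path: dequantize weights and QDQ activations in high precision.
    return use_mxfp4_w4a4 and not supports_mx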
@@ -1629,6 +1687,8 @@ def fused_experts_impl(
                                 use_int8_w8a8=use_int8_w8a8,
                                 use_int8_w8a16=use_int8_w8a16,
                                 use_int4_w4a16=use_int4_w4a16,
+                                use_mxfp4_w4a4=use_mxfp4_w4a4
+                                and current_platform.supports_mx(),
                                 per_channel_quant=per_channel_quant,
                                 block_shape=block_shape,
                                 B_bias=w1_bias)
@@ -1687,6 +1747,8 @@ def swiglu_oai(gate_up):
                                 use_int8_w8a8=use_int8_w8a8,
                                 use_int8_w8a16=use_int8_w8a16,
                                 use_int4_w4a16=use_int4_w4a16,
+                                use_mxfp4_w4a4=use_mxfp4_w4a4
+                                and current_platform.supports_mx(),
                                 per_channel_quant=per_channel_quant,
                                 block_shape=block_shape,
                                 B_bias=w2_bias)
@@ -1994,6 +2056,8 @@ def apply(
                 use_int8_w8a8=self.use_int8_w8a8,
                 use_int8_w8a16=self.use_int8_w8a16,
                 use_int4_w4a16=self.use_int4_w4a16,
+                use_mxfp4_w4a4=self.use_mxfp4_w4a4
+                and current_platform.supports_mx(),
                 per_channel_quant=self.per_act_token_quant,
                 block_shape=self.block_shape,
                 B_bias=None  # TODO support B_bias
@@ -2027,6 +2091,8 @@ def apply(
                 use_int8_w8a8=self.use_int8_w8a8,
                 use_int8_w8a16=self.use_int8_w8a16,
                 use_int4_w4a16=self.use_int4_w4a16,
+                use_mxfp4_w4a4=self.use_mxfp4_w4a4
+                and current_platform.supports_mx(),
                 per_channel_quant=self.per_act_token_quant,
                 block_shape=self.block_shape,
                 B_bias=None  # TODO support B_bias

vllm/model_executor/layers/fused_moe/utils.py

Lines changed: 8 additions & 5 deletions
@@ -17,6 +17,9 @@
 from vllm.utils import cdiv
 from vllm.utils.flashinfer import fp4_quantize
 
+if current_platform.supports_mx():
+    from aiter.ops.triton.quant import dynamic_mxfp4_quant
+
 
 @triton.jit
 def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts,
@@ -167,14 +170,14 @@ def _mxfp4_quantize(
     A_scale: Optional[torch.Tensor],
     per_act_token_quant: bool,
     block_shape: Optional[list[int]] = None,
-) -> tuple[torch.Tensor, None]:
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
     assert block_shape is None
     if not current_platform.supports_mx():
         A = quant_dequant_mxfp4(A)
-    else:
-        raise NotImplementedError()
-
-    return A, None
+        return A, A_scale
+    if A_scale is not None:
+        return A, A_scale
+    return dynamic_mxfp4_quant(A)
 
 
 def moe_kernel_quantize_input(
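On MX-capable platforms _mxfp4_quantize now returns a real activation scale instead of raising: it passes through a caller-provided scale, or calls dynamic_mxfp4_quant to produce one; without MX support it still QDQ-emulates. A standalone sketch of that control flow with stubbed quant ops (the stub bodies, the sketch name, and the supports_mx argument are illustrative, not vLLM code):

from typing import Optional

import torch


def quant_dequant_mxfp4(A: torch.Tensor) -> torch.Tensor:
    # Stub: stands in for the emulation quantize-dequantize op.
    return A


def dynamic_mxfp4_quant(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Stub: stands in for aiter.ops.triton.quant.dynamic_mxfp4_quant, which
    # returns the quantized activations together with their scales.
    return A, torch.ones(A.shape[0], dtype=torch.float32)


def mxfp4_quantize_sketch(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    supports_mx: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    if not supports_mx:
        # Emulation: QDQ the activations, keep any caller-provided scale.
        return quant_dequant_mxfp4(A), A_scale
    if A_scale is not None:
        # Activations already quantized upstream; pass both through.
        return A, A_scale
    # Native path: dynamically quantize and return the generated scale.
    return dynamic_mxfp4_quant(A)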

vllm/model_executor/layers/quantization/quark/quark_moe.py

Lines changed: 1 addition & 7 deletions
@@ -282,13 +282,7 @@ def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
                 "QDQ (quantize and dequantize) will be used, with the linear "
                 "layers computed in high precision.")
         else:
-            self.emulate = True
-            logger.warning_once(
-                "The current platform supports native MXFP4 "
-                "computation, but kernels are not yet integrated in vLLM. "
-                "Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision.")
+            self.emulate = False
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,

vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

Lines changed: 4 additions & 7 deletions
@@ -144,13 +144,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         )
         weight_quantizer.scale.data = layer.weight_scale.data
 
-        if not envs.VLLM_QUARK_EMU_MEM_OPT:
-            layer.weight = torch.nn.Parameter(
-                weight_quantizer(layer.weight.data).to(self.out_dtype),
-                requires_grad=False,
-            )
-        else:
-            self.weight_quantizer = weight_quantizer
+        layer.weight = torch.nn.Parameter(
+            weight_quantizer(layer.weight.data).to(self.out_dtype),
+            requires_grad=False,
+        )
         layer.weight_scale = None
 
         # This call is necessary to release the scales memory.
