Commit c1ffcb5

[Refactor] Optimize FP8 MOE Backend Choice and Log (vllm-project#26044)
Signed-off-by: yewentao256 <[email protected]>
1 parent 0879736 commit c1ffcb5

File tree

1 file changed (+71, -46 lines)
  • vllm/model_executor/layers/quantization/fp8.py


vllm/model_executor/layers/quantization/fp8.py

Lines changed: 71 additions & 46 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
@@ -68,6 +69,65 @@
 logger = init_logger(__name__)
 
 
+class Fp8MoeBackend(Enum):
+    NONE = 0
+    FLASHINFER_TRTLLM = 1
+    FLASHINFER_CUTLASS = 2
+    DEEPGEMM = 3
+    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
+    MARLIN = 5
+    TRITON = 6
+
+
+def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
+    """
+    Select the primary FP8 MoE backend
+    Note: Shape-specific fallbacks may still occur at runtime.
+    """
+    # prefer FlashInfer backends when available and enabled on supported GPUs
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(100)
+            and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
+        backend = get_flashinfer_moe_backend()
+        if backend == FlashinferMoeBackend.TENSORRT_LLM:
+            logger.info_once(
+                "Using FlashInfer FP8 MoE TRTLLM backend for SM100")
+            return Fp8MoeBackend.FLASHINFER_TRTLLM
+        else:
+            logger.info_once(
+                "Using FlashInfer FP8 MoE CUTLASS backend for SM100")
+            return Fp8MoeBackend.FLASHINFER_CUTLASS
+
+    # weight-only path for older GPUs without native FP8
+    use_marlin = (not current_platform.has_device_capability(89)
+                  or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+    if current_platform.is_rocm():
+        use_marlin = False
+    if use_marlin:
+        logger.info_once("Using Marlin backend for FP8 MoE")
+        return Fp8MoeBackend.MARLIN
+
+    # DeepGEMM on supported platforms with block-quantized weights
+    if envs.VLLM_USE_DEEP_GEMM and block_quant:
+        if not has_deep_gemm():
+            logger.warning_once(
+                "DeepGEMM backend requested but not available.")
+        elif is_deep_gemm_supported():
+            logger.info_once("Using DeepGEMM backend for FP8 MoE")
+            return Fp8MoeBackend.DEEPGEMM
+
+    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(100) and block_quant):
+        logger.info_once(
+            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
+        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+
+    # default to Triton
+    logger.info_once("Using Triton backend for FP8 MoE")
+    return Fp8MoeBackend.TRITON
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
@@ -453,54 +513,19 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore
 
-        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-        # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
-        # Disable marlin for rocm
-        if current_platform.is_rocm():
-            self.use_marlin = False
+        self.fp8_backend = get_fp8_moe_backend(self.block_quant)
 
-        # First check for Flashinfer MOE on Blackwell GPUs
+        self.use_marlin = (self.fp8_backend == Fp8MoeBackend.MARLIN)
         self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
-        if (current_platform.is_cuda()
-                and current_platform.is_device_capability(100)
-                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
-            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
-            logger.info_once(
-                f"Detected Blackwell GPUs, using FlashInfer "
-                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
-
-        # Check for DeepGemm support.
-        self.allow_deep_gemm = False
-        if envs.VLLM_USE_DEEP_GEMM:
-            if not has_deep_gemm():
-                logger.warning_once("Failed to import DeepGemm kernels.")
-            elif not self.block_quant:
-                logger.warning_once("Model is not block quantized. Not using"
-                                    " DeepGemm kernels")
-            elif self.flashinfer_moe_backend:
-                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
-                                 " enabled.")
-            elif (is_deep_gemm_supported()):
-                logger.debug_once(
-                    "DeepGemm kernels available for Fp8MoEMethod.")
-                self.allow_deep_gemm = True
-            else:
-                logger.warning_once(
-                    "DeepGemm not supported on the current platform.")
-
-        # Check for CutlassBlockScaledGroupedGemm support.
-        self.allow_cutlass_block_scaled_grouped_gemm = False
-        if not self.block_quant:
-            logger.debug_once("Model is not block quantized. Not using "
-                              "CutlassBlockScaledGroupedGemm kernels")
-        elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(100)
-              and not self.flashinfer_moe_backend):
-            logger.debug_once(
-                "CutlassBlockScaledGroupedGemm available for Fp8MoEMethod.")
-            self.allow_cutlass_block_scaled_grouped_gemm = True
+        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
+        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+
+        self.allow_deep_gemm = (self.fp8_backend == Fp8MoeBackend.DEEPGEMM)
+        self.allow_cutlass_block_scaled_grouped_gemm = (
+            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+        )
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
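
For reference, the selection priority the new helper encodes is: FlashInfer (TRTLLM or CUTLASS) on SM100 GPUs when VLLM_USE_FLASHINFER_MOE_FP8 is set, Marlin weight-only kernels on GPUs without native FP8 support (pre-SM89), DeepGEMM for block-quantized weights when VLLM_USE_DEEP_GEMM is set, the CUTLASS block-scaled grouped GEMM on SM100 with block quantization, and finally Triton. Below is a minimal sketch, not part of this commit, of probing the selector directly; it assumes a CUDA build of vLLM where this module imports cleanly and that vllm.envs evaluates VLLM_USE_DEEP_GEMM lazily at attribute access.

# Standalone sketch (not from the commit): probe the new FP8 MoE backend
# selector under a DeepGEMM-requesting environment. Assumes vllm.envs reads
# VLLM_USE_DEEP_GEMM lazily, so setting it before the call takes effect.
import os

os.environ["VLLM_USE_DEEP_GEMM"] = "1"

from vllm.model_executor.layers.quantization.fp8 import (
    Fp8MoeBackend, get_fp8_moe_backend)

# block_quant=True corresponds to block-quantized FP8 checkpoints; on
# hardware without DeepGEMM support the selector falls back (e.g. to Triton).
backend = get_fp8_moe_backend(block_quant=True)
print(f"Selected FP8 MoE backend: {backend.name}")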
