
Commit ef28354

[Bugfix] Fix accuracy issue of TRTLLM FP8 MOE and improve logging (vllm-project#25895)
Signed-off-by: Pavani Majety <[email protected]>
1 parent f4db5e6 commit ef28354

2 files changed: 29 additions (+), 17 deletions (−)

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 23 additions & 16 deletions
@@ -434,14 +434,9 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None

-        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore
-        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
-            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
-            logger.info_once(
-                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
-            )
+
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
@@ -450,14 +445,27 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
         if current_platform.is_rocm():
             self.use_marlin = False

+        # First check for Flashinfer MOE on Blackwell GPUs
+        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
+        if (current_platform.is_cuda()
+                and current_platform.is_device_capability(100)
+                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
+            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            logger.info_once(
+                f"Detected Blackwell GPUs, using FlashInfer "
+                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
+
         # Check for DeepGemm support.
         self.allow_deep_gemm = False
         if envs.VLLM_USE_DEEP_GEMM:
             if not has_deep_gemm():
                 logger.warning_once("Failed to import DeepGemm kernels.")
             elif not self.block_quant:
-                logger.warning_once("Model is not block quantized. Not using "
-                                    "DeepGemm kernels")
+                logger.warning_once("Model is not block quantized. Not using"
+                                    " DeepGemm kernels")
+            elif self.flashinfer_moe_backend:
+                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
+                                 " enabled.")
             elif (is_deep_gemm_supported()):
                 logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                 self.allow_deep_gemm = True
@@ -471,15 +479,12 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
                 logger.debug_once("Model is not block quantized. Not using "
                                   "CutlassBlockScaledGroupedGemm kernels")
             elif (current_platform.is_cuda()
-                  and current_platform.is_device_capability(100)):
+                  and current_platform.is_device_capability(100)
+                  and not self.flashinfer_moe_backend):
                 logger.info_once(
-                    "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
-                )
+                    "Using CutlassBlockScaledGroupedGemm kernels for Fp8 MOE "
+                    "on SM100.")
                 self.allow_cutlass_block_scaled_grouped_gemm = True
-            else:
-                logger.warning_once(
-                    "CutlassBlockScaledGroupedGemm not supported on the current "
-                    "platform.")

     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
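
Taken together, the three hunks above change the backend gating in Fp8MoEMethod.__init__: FlashInfer FP8 MOE is now claimed only on CUDA devices reporting compute capability 10.0 (Blackwell), DeepGemm yields to FlashInfer when it is active, and the Cutlass block-scaled grouped GEMM path on SM100 is likewise skipped once FlashInfer owns the MOE path. Below is a condensed, standalone sketch of that gating; the function and its parameters are hypothetical stand-ins for vLLM's current_platform and envs lookups, not real vLLM APIs.

# Condensed sketch of the FP8 MOE backend gating after this change.
# All parameters are hypothetical stand-ins, not vLLM APIs.
from dataclasses import dataclass


@dataclass
class Fp8MoeBackendFlags:
    use_flashinfer: bool = False
    allow_deep_gemm: bool = False
    allow_cutlass_block_scaled_grouped_gemm: bool = False


def gate_fp8_moe_backends(is_cuda: bool,
                          capability: int,              # 90 = Hopper, 100 = Blackwell
                          block_quant: bool,
                          flashinfer_requested: bool,   # VLLM_USE_FLASHINFER_MOE_FP8
                          flashinfer_available: bool,
                          deep_gemm_requested: bool,    # VLLM_USE_DEEP_GEMM
                          deep_gemm_supported: bool) -> Fp8MoeBackendFlags:
    flags = Fp8MoeBackendFlags()
    is_sm100 = is_cuda and capability == 100

    # 1) FlashInfer FP8 MOE is only claimed on Blackwell (SM100).
    flags.use_flashinfer = (is_sm100 and flashinfer_requested
                            and flashinfer_available)

    # 2) DeepGemm needs block quantization and is disabled when
    #    FlashInfer MOE is enabled.
    flags.allow_deep_gemm = (deep_gemm_requested and block_quant
                             and not flags.use_flashinfer
                             and deep_gemm_supported)

    # 3) The Cutlass block-scaled grouped GEMM path on SM100 is now
    #    also skipped when FlashInfer owns the MOE path.
    flags.allow_cutlass_block_scaled_grouped_gemm = (
        is_sm100 and block_quant and not flags.use_flashinfer)

    return flags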
@@ -934,7 +939,9 @@ def apply(
             import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
             assert (renormalize and use_grouped_topk
                     and custom_routing_function is None)
-            result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+            e_score_correction_bias = (e_score_correction_bias.to(
+                x.dtype) if e_score_correction_bias is not None else None)
+            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                 routing_logits=router_logits.to(torch.float32),
                 routing_bias=e_score_correction_bias,
                 x=x,
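
The hunk above carries the accuracy-related change in the TRTLLM path: the expert-score correction bias is cast to the activation dtype before being passed as the routing bias, and the kernel output is now returned directly. A minimal sketch of the same dtype-alignment pattern follows; fused_moe_blockscale_fp8 here is a placeholder for the real torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8 custom op, which is only registered inside vLLM.

# Sketch of the dtype-alignment pattern from the hunk above.
from typing import Optional

import torch


def fused_moe_blockscale_fp8(routing_logits: torch.Tensor,
                             routing_bias: Optional[torch.Tensor],
                             x: torch.Tensor) -> torch.Tensor:
    # Placeholder for torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8;
    # the real op performs routing plus the block-scaled FP8 grouped GEMM.
    # Here we only preserve the output shape and dtype.
    return torch.zeros_like(x)


def apply_trtllm_fp8_moe(x: torch.Tensor,
                         router_logits: torch.Tensor,
                         e_score_correction_bias: Optional[torch.Tensor]
                         ) -> torch.Tensor:
    # Align the routing bias dtype with the activations before calling
    # the kernel, mirroring the change made in the hunk above.
    bias = (e_score_correction_bias.to(x.dtype)
            if e_score_correction_bias is not None else None)
    return fused_moe_blockscale_fp8(
        routing_logits=router_logits.to(torch.float32),
        routing_bias=bias,
        x=x,
    )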

vllm/utils/deep_gemm.py

Lines changed: 6 additions & 1 deletion
@@ -27,7 +27,8 @@ def is_deep_gemm_supported() -> bool:
     is_supported_arch = current_platform.is_cuda() and (
         current_platform.is_device_capability(90)
         or current_platform.is_device_capability(100))
-    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
+    return (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
+            and not envs.VLLM_USE_FLASHINFER_MOE_FP8)


 @functools.cache
@@ -46,6 +47,10 @@ def is_deep_gemm_e8m0_used() -> bool:
         logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
         return False

+    if envs.VLLM_USE_FLASHINFER_MOE_FP8:
+        logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
+        return False
+
     if current_platform.is_device_capability(100) and \
             envs.VLLM_USE_DEEP_GEMM_E8M0:
         logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
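
In vllm/utils/deep_gemm.py, the same environment flag now short-circuits both helpers: is_deep_gemm_supported() reports False and the E8M0 scale format is left unused whenever FlashInfer FP8 MOE is requested. A standalone sketch of these guards, with the envs and current_platform lookups replaced by hypothetical parameters:

# Standalone sketch of the deep_gemm.py guards above; the boolean/int
# parameters are hypothetical stand-ins for vllm.envs and current_platform.
def is_deep_gemm_supported(use_deep_gemm_env: bool,
                           has_deep_gemm: bool,
                           capability: int,
                           use_flashinfer_moe_fp8_env: bool) -> bool:
    is_supported_arch = capability in (90, 100)  # Hopper or Blackwell
    # FlashInfer FP8 MOE takes priority, so DeepGEMM reports unsupported.
    return (use_deep_gemm_env and has_deep_gemm and is_supported_arch
            and not use_flashinfer_moe_fp8_env)


def is_deep_gemm_e8m0_used(use_flashinfer_moe_fp8_env: bool,
                           capability: int,
                           use_e8m0_env: bool) -> bool:
    if use_flashinfer_moe_fp8_env:
        # DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.
        return False
    return capability == 100 and use_e8m0_env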
