@@ -434,14 +434,9 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None

-        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore
-        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
-            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
-            logger.info_once(
-                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
-            )
+
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
@@ -450,14 +445,27 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
         if current_platform.is_rocm():
             self.use_marlin = False

+        # First check for Flashinfer MOE on Blackwell GPUs
+        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
+        if (current_platform.is_cuda()
+                and current_platform.is_device_capability(100)
+                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
+            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            logger.info_once(
+                f"Detected Blackwell GPUs, using FlashInfer "
+                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
+
         # Check for DeepGemm support.
         self.allow_deep_gemm = False
         if envs.VLLM_USE_DEEP_GEMM:
             if not has_deep_gemm():
                 logger.warning_once("Failed to import DeepGemm kernels.")
             elif not self.block_quant:
-                logger.warning_once("Model is not block quantized. Not using "
-                                    "DeepGemm kernels")
+                logger.warning_once("Model is not block quantized. Not using"
+                                    " DeepGemm kernels")
+            elif self.flashinfer_moe_backend:
+                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
+                                 " enabled.")
             elif (is_deep_gemm_supported()):
                 logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                 self.allow_deep_gemm = True
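For context (not part of the diff): the FlashInfer path added above is opt-in via the VLLM_USE_FLASHINFER_MOE_FP8 environment variable and only activates on SM100 (Blackwell) GPUs with flashinfer available. A minimal sketch of opting in; the model id below is a placeholder, not something from this PR:

    import os
    os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"  # env var checked in __init__ above

    from vllm import LLM

    # On a Blackwell GPU with flashinfer installed, __init__ selects the
    # FlashInfer backend; on other hardware it silently falls through.
    llm = LLM(model="path/to/fp8-moe-checkpoint")  # placeholder model id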
@@ -471,15 +479,12 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
             logger.debug_once("Model is not block quantized. Not using "
                               "CutlassBlockScaledGroupedGemm kernels")
         elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(100)):
+              and current_platform.is_device_capability(100)
+              and not self.flashinfer_moe_backend):
             logger.info_once(
-                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod. "
-            )
+                "Using CutlassBlockScaledGroupedGemm kernels for Fp8 MOE "
+                "on SM100.")
             self.allow_cutlass_block_scaled_grouped_gemm = True
-        else:
-            logger.warning_once(
-                "CutlassBlockScaledGroupedGemm not supported on the current "
-                "platform.")

     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
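Taken together, the hunks above establish a selection order among the FP8 MoE kernels. A simplified paraphrase (not vLLM source; the function and names are illustrative) of the priority that __init__ ends up with:

    def _select_fp8_moe_backend(is_sm100: bool, flashinfer_on: bool,
                                deep_gemm_on: bool, block_quant: bool) -> str:
        # Illustrative only: condensed from the checks in __init__ above.
        if is_sm100 and flashinfer_on:
            return "flashinfer"            # checked first, Blackwell only
        if deep_gemm_on and block_quant:
            return "deep_gemm"             # now skipped when FlashInfer is active
        if is_sm100 and block_quant:
            return "cutlass_block_scaled"  # no longer warns on other platforms
        return "default (Triton/Marlin)"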
@@ -934,7 +939,9 @@ def apply(
             import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
             assert (renormalize and use_grouped_topk
                     and custom_routing_function is None)
-            result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+            e_score_correction_bias = (e_score_correction_bias.to(
+                x.dtype) if e_score_correction_bias is not None else None)
+            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                 routing_logits=router_logits.to(torch.float32),
                 routing_bias=e_score_correction_bias,
                 x=x,
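The last hunk also aligns the routing bias dtype with the activations before calling the FlashInfer op. A standalone illustration of that cast; the tensor shapes here are made up for the example:

    import torch

    x = torch.randn(4, 8, dtype=torch.bfloat16)                    # activations
    e_score_correction_bias = torch.zeros(16, dtype=torch.float32)  # routing bias

    # Same pattern as the added lines: cast the bias to the activation dtype
    # (or keep None) so an fp32 bias does not clash with bf16/fp16 activations.
    e_score_correction_bias = (e_score_correction_bias.to(x.dtype)
                               if e_score_correction_bias is not None else None)
    assert e_score_correction_bias.dtype == x.dtype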