1 | 1 | # SPDX-License-Identifier: Apache-2.0
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3 | 3 |
| 4 | +from enum import Enum
4 | 5 | from typing import TYPE_CHECKING, Any, Callable, Optional, Union
5 | 6 |
6 | 7 | import torch
68 | 69 | logger = init_logger(__name__)
69 | 70 |
70 | 71 |
| 72 | +class Fp8MoeBackend(Enum):
| 73 | +    NONE = 0
| 74 | +    FLASHINFER_TRTLLM = 1
| 75 | +    FLASHINFER_CUTLASS = 2
| 76 | +    DEEPGEMM = 3
| 77 | +    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
| 78 | +    MARLIN = 5
| 79 | +    TRITON = 6
| 80 | +
| 81 | +
| 82 | +def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
| 83 | +    """
| 84 | +    Select the primary FP8 MoE backend.
| 85 | +    Note: Shape-specific fallbacks may still occur at runtime.
| 86 | +    """
| 87 | +    # prefer FlashInfer backends when available and enabled on supported GPUs
| 88 | +    if (current_platform.is_cuda()
| 89 | +            and current_platform.is_device_capability(100)
| 90 | +            and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
| 91 | +        backend = get_flashinfer_moe_backend()
| 92 | +        if backend == FlashinferMoeBackend.TENSORRT_LLM:
| 93 | +            logger.info_once(
| 94 | +                "Using FlashInfer FP8 MoE TRTLLM backend for SM100")
| 95 | +            return Fp8MoeBackend.FLASHINFER_TRTLLM
| 96 | +        else:
| 97 | +            logger.info_once(
| 98 | +                "Using FlashInfer FP8 MoE CUTLASS backend for SM100")
| 99 | +            return Fp8MoeBackend.FLASHINFER_CUTLASS
| 100 | +
| 101 | +    # weight-only path for older GPUs without native FP8
| 102 | +    use_marlin = (not current_platform.has_device_capability(89)
| 103 | +                  or envs.VLLM_TEST_FORCE_FP8_MARLIN)
| 104 | +    if current_platform.is_rocm():
| 105 | +        use_marlin = False
| 106 | +    if use_marlin:
| 107 | +        logger.info_once("Using Marlin backend for FP8 MoE")
| 108 | +        return Fp8MoeBackend.MARLIN
| 109 | +
| 110 | +    # DeepGEMM on supported platforms with block-quantized weights
| 111 | +    if envs.VLLM_USE_DEEP_GEMM and block_quant:
| 112 | +        if not has_deep_gemm():
| 113 | +            logger.warning_once(
| 114 | +                "DeepGEMM backend requested but not available.")
| 115 | +        elif is_deep_gemm_supported():
| 116 | +            logger.info_once("Using DeepGEMM backend for FP8 MoE")
| 117 | +            return Fp8MoeBackend.DEEPGEMM
| 118 | +
| 119 | +    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
| 120 | +    if (current_platform.is_cuda()
| 121 | +            and current_platform.is_device_capability(100) and block_quant):
| 122 | +        logger.info_once(
| 123 | +            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
| 124 | +        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
| 125 | +
| 126 | +    # default to Triton
| 127 | +    logger.info_once("Using Triton backend for FP8 MoE")
| 128 | +    return Fp8MoeBackend.TRITON
| 129 | +
| 130 | +
71 | 131 | class Fp8Config(QuantizationConfig):
72 | 132 |     """Config class for FP8."""
73 | 133 |
@@ -453,54 +513,19 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
453 | 513 |         self.fused_experts: Optional[
454 | 514 |             mk.FusedMoEModularKernel] = None  # type: ignore
455 | 515 |
456 | | -        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
457 | | -        # kernel for fast weight-only FP8 quantization
458 | | -        self.use_marlin = (not current_platform.has_device_capability(89)
459 | | -                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
460 | | -        # Disable marlin for rocm
461 | | -        if current_platform.is_rocm():
462 | | -            self.use_marlin = False
| 516 | +        self.fp8_backend = get_fp8_moe_backend(self.block_quant)
463 | 517 |
464 | | -        # First check for Flashinfer MOE on Blackwell GPUs
| 518 | +        self.use_marlin = (self.fp8_backend == Fp8MoeBackend.MARLIN)
465 | 519 |         self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
466 | | -        if (current_platform.is_cuda()
467 | | -                and current_platform.is_device_capability(100)
468 | | -                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
469 | | -            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
470 | | -            logger.info_once(
471 | | -                f"Detected Blackwell GPUs, using FlashInfer "
472 | | -                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
473 | | -
474 | | -        # Check for DeepGemm support.
475 | | -        self.allow_deep_gemm = False
476 | | -        if envs.VLLM_USE_DEEP_GEMM:
477 | | -            if not has_deep_gemm():
478 | | -                logger.warning_once("Failed to import DeepGemm kernels.")
479 | | -            elif not self.block_quant:
480 | | -                logger.warning_once("Model is not block quantized. Not using"
481 | | -                                    " DeepGemm kernels")
482 | | -            elif self.flashinfer_moe_backend:
483 | | -                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
484 | | -                                 " enabled.")
485 | | -            elif (is_deep_gemm_supported()):
486 | | -                logger.debug_once(
487 | | -                    "DeepGemm kernels available for Fp8MoEMethod.")
488 | | -                self.allow_deep_gemm = True
489 | | -            else:
490 | | -                logger.warning_once(
491 | | -                    "DeepGemm not supported on the current platform.")
492 | | -
493 | | -        # Check for CutlassBlockScaledGroupedGemm support.
494 | | -        self.allow_cutlass_block_scaled_grouped_gemm = False
495 | | -        if not self.block_quant:
496 | | -            logger.debug_once("Model is not block quantized. Not using "
497 | | -                              "CutlassBlockScaledGroupedGemm kernels")
498 | | -        elif (current_platform.is_cuda()
499 | | -              and current_platform.is_device_capability(100)
500 | | -              and not self.flashinfer_moe_backend):
501 | | -            logger.debug_once(
502 | | -                "CutlassBlockScaledGroupedGemm available for Fp8MoEMethod.")
503 | | -            self.allow_cutlass_block_scaled_grouped_gemm = True
| 520 | +        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
| 521 | +            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
| 522 | +        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
| 523 | +            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
| 524 | +
| 525 | +        self.allow_deep_gemm = (self.fp8_backend == Fp8MoeBackend.DEEPGEMM)
| 526 | +        self.allow_cutlass_block_scaled_grouped_gemm = (
| 527 | +            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
| 528 | +        )
504 | 529 |
505 | 530 |     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
506 | 531 |                        intermediate_size_per_partition: int,
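
For readers tracing the refactor, the net effect is that the backend decision is made once, in `get_fp8_moe_backend`, and the per-backend flags in `Fp8MoEMethod.__init__` (`use_marlin`, `allow_deep_gemm`, `allow_cutlass_block_scaled_grouped_gemm`, `flashinfer_moe_backend`) become simple comparisons against the returned enum. The sketch below is a minimal, standalone illustration of that priority order only; it is not vLLM code, and the `Caps` fields are hypothetical stand-ins for the `current_platform`, `envs`, and `has_flashinfer_moe()` checks used above.

```python
from dataclasses import dataclass
from enum import Enum


class Backend(Enum):
    FLASHINFER = 1
    MARLIN = 2
    DEEPGEMM = 3
    CUTLASS_BLOCK_SCALED = 4
    TRITON = 5


@dataclass
class Caps:
    # Hypothetical stand-ins for the platform/env checks in the diff.
    is_sm100: bool = False           # Blackwell-class CUDA GPU
    has_fp8_hw: bool = True          # compute capability >= 8.9
    flashinfer_enabled: bool = False
    deep_gemm_enabled: bool = False


def pick_backend(caps: Caps, block_quant: bool) -> Backend:
    """Mirror the priority order encoded by get_fp8_moe_backend."""
    if caps.is_sm100 and caps.flashinfer_enabled:
        return Backend.FLASHINFER
    if not caps.has_fp8_hw:
        return Backend.MARLIN            # weight-only fallback
    if caps.deep_gemm_enabled and block_quant:
        return Backend.DEEPGEMM
    if caps.is_sm100 and block_quant:
        return Backend.CUTLASS_BLOCK_SCALED
    return Backend.TRITON


# e.g. an SM100 GPU with block-quantized weights but FlashInfer disabled
assert (pick_backend(Caps(is_sm100=True), block_quant=True)
        == Backend.CUTLASS_BLOCK_SCALED)
```

As the new docstring notes, shape-specific fallbacks can still happen at runtime; the enum only records the primary choice made at initialization.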