|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 4 | +""" |
| 5 | +Benchmark comparing old vs new default fused MoE configs. |
| 6 | +
|
| 7 | +Runs the triton fused_moe kernel with three configurations for each scenario: |
| 8 | + 1. Tuned config (from JSON file, if available) — the target to match |
| 9 | + 2. Old default (the hardcoded defaults before this change) |
| 10 | + 3. New default (the improved defaults) |
| 11 | +
|
| 12 | +Usage: |
| 13 | + python benchmarks/kernels/benchmark_moe_defaults.py |
| 14 | +
|
| 15 | +Produces a table showing kernel time (us) and speedup of new vs old defaults. |
| 16 | +""" |
| 17 | + |
| 18 | +import torch |
| 19 | + |
| 20 | +from vllm.model_executor.layers.fused_moe import fused_topk, override_config |
| 21 | +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig |
| 22 | +from vllm.model_executor.layers.fused_moe.fused_moe import ( |
| 23 | + fused_experts, |
| 24 | + get_default_config, |
| 25 | + get_moe_configs, |
| 26 | +) |
| 27 | +from vllm.platforms import current_platform |
| 28 | +from vllm.triton_utils import triton |
| 29 | +from vllm.utils.torch_utils import set_random_seed |
| 30 | + |
| 31 | +FP8_DTYPE = current_platform.fp8_dtype() |
| 32 | + |
| 33 | + |
| 34 | +def old_default_config(M, E, N, K, topk, dtype=None, block_shape=None): |
| 35 | + """The original defaults before https://github.com/vllm-project/vllm/pull/34846, |
| 36 | + for comparison.""" |
| 37 | + if dtype == "fp8_w8a8" and block_shape is not None: |
| 38 | + return { |
| 39 | + "BLOCK_SIZE_M": 64, |
| 40 | + "BLOCK_SIZE_N": block_shape[0], |
| 41 | + "BLOCK_SIZE_K": block_shape[1], |
| 42 | + "GROUP_SIZE_M": 32, |
| 43 | + "SPLIT_K": 1, |
| 44 | + "num_warps": 4, |
| 45 | + "num_stages": 3 if not current_platform.is_rocm() else 2, |
| 46 | + } |
| 47 | + elif M <= E: |
| 48 | + return { |
| 49 | + "BLOCK_SIZE_M": 16, |
| 50 | + "BLOCK_SIZE_N": 32, |
| 51 | + "BLOCK_SIZE_K": 64, |
| 52 | + "GROUP_SIZE_M": 1, |
| 53 | + "SPLIT_K": 1, |
| 54 | + } |
| 55 | + else: |
| 56 | + return { |
| 57 | + "BLOCK_SIZE_M": 64, |
| 58 | + "BLOCK_SIZE_N": 64, |
| 59 | + "BLOCK_SIZE_K": 32, |
| 60 | + "GROUP_SIZE_M": 8, |
| 61 | + "SPLIT_K": 1, |
| 62 | + } |
| 63 | + |
| 64 | + |
| 65 | +def benchmark_config( |
| 66 | + config, |
| 67 | + M, |
| 68 | + E, |
| 69 | + N, |
| 70 | + K, |
| 71 | + topk, |
| 72 | + dtype, |
| 73 | + use_fp8=False, |
| 74 | + block_shape=None, |
| 75 | + num_iters=100, |
| 76 | +): |
| 77 | + """Time a single kernel config. Returns kernel time in microseconds.""" |
| 78 | + init_dtype = torch.float16 if use_fp8 else dtype |
| 79 | + |
| 80 | + a = torch.randn(M, K, device="cuda", dtype=init_dtype) / 10 |
| 81 | + w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=init_dtype) / 10 |
| 82 | + w2 = torch.randn(E, K, N, device="cuda", dtype=init_dtype) / 10 |
| 83 | + |
| 84 | + w1_scale = None |
| 85 | + w2_scale = None |
| 86 | + a1_scale = None |
| 87 | + a2_scale = None |
| 88 | + if use_fp8: |
| 89 | + if block_shape is not None: |
| 90 | + bsn, bsk = block_shape |
| 91 | + n_tiles_w1 = triton.cdiv(2 * N, bsn) |
| 92 | + k_tiles_w1 = triton.cdiv(K, bsk) |
| 93 | + n_tiles_w2 = triton.cdiv(K, bsn) |
| 94 | + k_tiles_w2 = triton.cdiv(N, bsk) |
| 95 | + w1_scale = torch.rand( |
| 96 | + E, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32 |
| 97 | + ) |
| 98 | + w2_scale = torch.rand( |
| 99 | + E, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32 |
| 100 | + ) |
| 101 | + else: |
| 102 | + w1_scale = torch.rand(E, device="cuda", dtype=torch.float32) |
| 103 | + w2_scale = torch.rand(E, device="cuda", dtype=torch.float32) |
| 104 | + a1_scale = torch.rand(1, device="cuda", dtype=torch.float32) |
| 105 | + a2_scale = torch.rand(1, device="cuda", dtype=torch.float32) |
| 106 | + # Only weights are stored in fp8; activations stay in bf16/fp16 |
| 107 | + # and get dynamically quantized inside the kernel. |
| 108 | + w1 = w1.to(FP8_DTYPE) |
| 109 | + w2 = w2.to(FP8_DTYPE) |
| 110 | + |
| 111 | + quant_config = FusedMoEQuantConfig.make( |
| 112 | + quant_dtype=torch.float8_e4m3fn if use_fp8 else None, |
| 113 | + w1_scale=w1_scale, |
| 114 | + w2_scale=w2_scale, |
| 115 | + a1_scale=a1_scale, |
| 116 | + a2_scale=a2_scale, |
| 117 | + block_shape=block_shape, |
| 118 | + ) |
| 119 | + |
| 120 | + gating = torch.randn(M, E, device="cuda", dtype=torch.float32) |
| 121 | + |
| 122 | + # Warmup |
| 123 | + for _ in range(20): |
| 124 | + with override_config(config): |
| 125 | + topk_weights, topk_ids, _ = fused_topk(a, gating, topk, renormalize=True) |
| 126 | + fused_experts( |
| 127 | + a, |
| 128 | + w1, |
| 129 | + w2, |
| 130 | + topk_weights, |
| 131 | + topk_ids, |
| 132 | + quant_config=quant_config, |
| 133 | + ) |
| 134 | + torch.cuda.synchronize() |
| 135 | + |
| 136 | + # Benchmark |
| 137 | + start = torch.cuda.Event(enable_timing=True) |
| 138 | + end = torch.cuda.Event(enable_timing=True) |
| 139 | + start.record() |
| 140 | + for _ in range(num_iters): |
| 141 | + with override_config(config): |
| 142 | + topk_weights, topk_ids, _ = fused_topk(a, gating, topk, renormalize=True) |
| 143 | + fused_experts( |
| 144 | + a, |
| 145 | + w1, |
| 146 | + w2, |
| 147 | + topk_weights, |
| 148 | + topk_ids, |
| 149 | + quant_config=quant_config, |
| 150 | + ) |
| 151 | + end.record() |
| 152 | + torch.cuda.synchronize() |
| 153 | + return start.elapsed_time(end) / num_iters * 1000 # ms -> us |
| 154 | + |
| 155 | + |
| 156 | +# Model configurations: (name, E, N, K, topk, dtype_str, use_fp8, block_shape) |
| 157 | +# N = moe_intermediate_size // tp_size (the value used in config file lookup) |
| 158 | +MODELS = [ |
| 159 | + # --- Few experts --- |
| 160 | + ("Mixtral bf16", 8, 7168, 4096, 2, None, False, None), |
| 161 | + ("Mixtral fp8", 8, 7168, 4096, 2, "fp8_w8a8", True, None), |
| 162 | + # --- Many experts: real model shapes at tp=1 --- |
| 163 | + # Qwen2-MoE-57B: E=60, topk=4, N=1408, K=2048 |
| 164 | + ("Qwen2-MoE bf16", 60, 1408, 2048, 4, None, False, None), |
| 165 | + # DeepSeek-V2: E=64, topk=6, N=1407, K=4096 |
| 166 | + # (use 1408 to avoid odd alignment; real model is 1407) |
| 167 | + ("DeepSeek-V2 bf16", 64, 1408, 4096, 6, None, False, None), |
| 168 | + # OLMoE-7B: E=64, topk=8, N=2048, K=2048 |
| 169 | + ("OLMoE bf16", 64, 2048, 2048, 8, None, False, None), |
| 170 | + # GLM-4-100B-A10B: E=128, topk=8, N=1408, K=4096 |
| 171 | + ("GLM-4-MoE bf16", 128, 1408, 4096, 8, None, False, None), |
| 172 | + # Qwen3-30B-A3B: E=128, topk=8, N=768, K=2048 |
| 173 | + ("Qwen3-MoE bf16", 128, 768, 2048, 8, None, False, None), |
| 174 | + # DeepSeek-V3 / MiMo-V2-Flash: E=256, topk=8, N=2048, K=7168 |
| 175 | + ("DeepSeek-V3 bf16", 256, 2048, 7168, 8, None, False, None), |
| 176 | + # Qwen3.5-70B-A22B (Qwen3-Next): E=512, topk=10, N=512, K=2048 |
| 177 | + ("Qwen3-Next bf16", 512, 512, 2048, 10, None, False, None), |
| 178 | + # E=128 N=1856 bf16 |
| 179 | + ("E128 N1856 bf16", 128, 1856, 4096, 8, None, False, None), |
| 180 | + # E=256 N=512 bf16 (DS-V3 tp=4) |
| 181 | + ("DS-V3 tp4 bf16", 256, 512, 7168, 8, None, False, None), |
| 182 | + # E=512 N=512 bf16 (Qwen3-Next tp=1) |
| 183 | + ("Qwen3-Next bf16", 512, 512, 2048, 10, None, False, None), |
| 184 | + # E=512 N=256 bf16 (Qwen3-Next tp=2) |
| 185 | + ("Qwen3-Next tp2", 512, 256, 2048, 10, None, False, None), |
| 186 | + # --- FP8 block quant (many experts) --- |
| 187 | + # DS-V3 tp=4: E=256, N=512, fp8 block |
| 188 | + ("DS-V3 tp4 fp8blk", 256, 512, 7168, 8, "fp8_w8a8", True, [128, 128]), |
| 189 | + # DS-V3 tp=8: E=256, N=256, fp8 block |
| 190 | + ("DS-V3 tp8 fp8blk", 256, 256, 7168, 8, "fp8_w8a8", True, [128, 128]), |
| 191 | + # Qwen3-Next tp=2 fp8 block |
| 192 | + ("Qwen3-Next tp2 fp8blk", 512, 256, 2048, 10, "fp8_w8a8", True, [128, 128]), |
| 193 | +] |
| 194 | + |
| 195 | +BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] |
| 196 | + |
| 197 | + |
| 198 | +def main(): |
| 199 | + set_random_seed(0) |
| 200 | + torch.set_default_device("cuda") |
| 201 | + dtype = torch.bfloat16 |
| 202 | + |
| 203 | + for name, E, N, K, topk, dtype_str, use_fp8, block_shape in MODELS: |
| 204 | + print(f"\n{'=' * 90}") |
| 205 | + print(f" {name} (E={E}, N={N}, K={K}, topk={topk})") |
| 206 | + print(f"{'=' * 90}") |
| 207 | + |
| 208 | + # Try to load tuned config |
| 209 | + block_n = block_shape[0] if block_shape else None |
| 210 | + block_k = block_shape[1] if block_shape else None |
| 211 | + tuned = get_moe_configs(E, N, dtype_str, block_n, block_k) |
| 212 | + has_tuned = tuned is not None |
| 213 | + print(f" Tuned config available: {has_tuned}") |
| 214 | + |
| 215 | + hdr = ( |
| 216 | + f"{'Batch':>6} | {'Tuned (us)':>11} | {'Old (us)':>11} | " |
| 217 | + f"{'New (us)':>11} | {'New/Old':>8} | {'New/Tuned':>10}" |
| 218 | + ) |
| 219 | + print(f" {hdr}") |
| 220 | + print(f" {'-' * len(hdr)}") |
| 221 | + |
| 222 | + for M in BATCH_SIZES: |
| 223 | + old_cfg = old_default_config(M, E, N, K, topk, dtype_str, block_shape) |
| 224 | + new_cfg = get_default_config(M, E, N, K, topk, dtype_str, block_shape) |
| 225 | + |
| 226 | + if has_tuned: |
| 227 | + tuned_cfg = tuned[min(tuned.keys(), key=lambda x: abs(x - M))] |
| 228 | + t_tuned = benchmark_config( |
| 229 | + tuned_cfg, |
| 230 | + M, |
| 231 | + E, |
| 232 | + N, |
| 233 | + K, |
| 234 | + topk, |
| 235 | + dtype, |
| 236 | + use_fp8=use_fp8, |
| 237 | + block_shape=block_shape, |
| 238 | + ) |
| 239 | + else: |
| 240 | + t_tuned = None |
| 241 | + |
| 242 | + t_old = benchmark_config( |
| 243 | + old_cfg, |
| 244 | + M, |
| 245 | + E, |
| 246 | + N, |
| 247 | + K, |
| 248 | + topk, |
| 249 | + dtype, |
| 250 | + use_fp8=use_fp8, |
| 251 | + block_shape=block_shape, |
| 252 | + ) |
| 253 | + t_new = benchmark_config( |
| 254 | + new_cfg, |
| 255 | + M, |
| 256 | + E, |
| 257 | + N, |
| 258 | + K, |
| 259 | + topk, |
| 260 | + dtype, |
| 261 | + use_fp8=use_fp8, |
| 262 | + block_shape=block_shape, |
| 263 | + ) |
| 264 | + |
| 265 | + ratio_new_old = t_new / t_old |
| 266 | + tuned_str = f"{t_tuned:11.2f}" if t_tuned else f"{'N/A':>11}" |
| 267 | + ratio_tuned = f"{t_new / t_tuned:10.2f}x" if t_tuned else f"{'N/A':>10}" |
| 268 | + # flag regressions where new default is >5% slower than old |
| 269 | + marker = " <--" if ratio_new_old > 1.05 else "" |
| 270 | + |
| 271 | + print( |
| 272 | + f" {M:>6} | {tuned_str} | {t_old:11.2f} | {t_new:11.2f} " |
| 273 | + f"| {ratio_new_old:7.2f}x | {ratio_tuned}{marker}" |
| 274 | + ) |
| 275 | + |
| 276 | + |
| 277 | +if __name__ == "__main__": |
| 278 | + main() |
0 commit comments