From 3383f88b0bba5b76a9e601b6446ec0c7c2159a89 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 6 Oct 2025 09:15:38 +0000 Subject: [PATCH 01/13] Updated benchmark --- .../third_party/vllm/batched_moe_benchmark.py | 649 ++++++++++++++++++ 1 file changed, 649 insertions(+) create mode 100644 benchmarks/third_party/vllm/batched_moe_benchmark.py diff --git a/benchmarks/third_party/vllm/batched_moe_benchmark.py b/benchmarks/third_party/vllm/batched_moe_benchmark.py new file mode 100644 index 0000000000..ce63b02a62 --- /dev/null +++ b/benchmarks/third_party/vllm/batched_moe_benchmark.py @@ -0,0 +1,649 @@ +# pylint: skip-file +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Batched MoE benchmark +===================== + +This benchmark is based on the test_batched_moe.py tests and follows +the framework from gemm_benchmark.py to compare performance of different +batched MoE implementations using vLLM kernels. + +""" +from typing import Optional +import os + +import torch +import triton +import triton.language as tl + +import triton_kernels_benchmark as benchmark_suite + +# Import vLLM MoE functions +from vllm.model_executor.layers.fused_moe.fused_batched_moe import invoke_moe_batched_triton_kernel +from vllm.platforms import current_platform +from vllm.model_executor.layers.fused_moe.utils import normalize_batched_scales_shape + +# Import utility functions from vLLM tests +from tests.kernels.moe.utils import make_quantized_test_activations, make_test_weights +from tests.kernels.quant_utils import native_batched_masked_quant_matmul + + +@triton.jit +def moe_mmk( + a_desc, + b_desc, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_asm: tl.int64, + stride_ask: tl.int64, + stride_bse: tl.int64, + stride_bsk: tl.int64, + stride_bsn: tl.int64, + # Offsets and masks + offs_m, + offs_n, + offs_bn, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + pid_m, + pid_n, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr, + per_act_token_quant: tl.constexpr, +): + + if use_w8a16: + b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn + + # per act token + elif per_act_token_quant: + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None] + + b_scale_ptrs = b_scale_ptr + offs_bn[None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
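+    # NOTE: `a_desc` and `b_desc` arrive as Triton tensor descriptors built by
+    # the caller, so each step of the K-loop below issues a block load instead
+    # of masked per-element pointer arithmetic.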
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Load the next block of A and B using tensor descriptors + a = a_desc.load([pid_m * BLOCK_M, k * BLOCK_K]) + b = b_desc.load([k * BLOCK_K, pid_n * BLOCK_N]) + + # We accumulate along the K dimension. + if use_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, mask=mask_m, other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + # acc used to enable fp8_fast_accum + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + + if use_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + + return accumulator + + +@triton.jit +def expert_triton_kernel( + a_desc, #[max_tokens, K] + b_desc, #[K, N] + c_desc, #[max_tokens, N] + expert_id, + compute_type: tl.constexpr, + # Dimensions + M, + N, + K, + # Quantization data + a_scale_ptr, + b_scale_ptr, + # strides + stride_ak: tl.int64, + stride_bk: tl.int64, + stride_ase: tl.int64, + stride_asm: tl.int64, + stride_ask: tl.int64, + stride_bse: tl.int64, + stride_bsk: tl.int64, + stride_bsn: tl.int64, + # offsets + offs_bn, + # Blockwise quantization data + group_n, + group_k, + pid_m, + pid_n, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_act_token_quant: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) % N + # offs_k = tl.arange(0, BLOCK_K) + mask_m = offs_m < M + + accumulator = moe_mmk( + a_desc, b_desc, K, expert_id, a_scale_ptr, b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). 
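+        # Only the scale strides are forwarded here; tile addressing for A and
+        # B is already encoded in the tensor descriptors created by the caller.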
+ stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn, + # Offsets and masks + offs_m, offs_n, offs_bn, mask_m, + # Block size for block-wise quantization + group_n, group_k, pid_m, pid_n, + # Meta-parameters + BLOCK_M, BLOCK_N, BLOCK_K, compute_type, use_fp8_w8a8, use_int8_w8a16, per_act_token_quant) + + # store in C + # offs_cn = tl.arange(0, BLOCK_N) + # c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn + # c_mask = mask_m[:, None] & (offs_cn[None, :] < N) + c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], accumulator) + # tl.store(c_ptrs, accumulator, mask=c_mask) + + +# def get_matmul_batched_autotune_configs(): +# configs = [ +# triton.Config( +# {'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': 'large'}, +# num_stages=s, num_warps=32) for s in [2, 3] +# ] + [ +# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': m}, +# num_stages=s, num_warps=w) for s in [2] for (m, w) in ([('large', 32), ('small', 64)]) +# ] + [ +# triton.Config( +# {'BLOCK_M': 128, 'BLOCK_N': 1024, 'BLOCK_K': 16, 'grf_mode': 'large'}, +# num_stages=s, num_warps=32) for s in [2, 3] +# ] + [ +# triton.Config( +# {'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': 'large'}, +# num_stages=s, num_warps=32) for s in [2] +# ] + [ +# triton.Config( +# {'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': 'large'}, +# num_stages=s, num_warps=32) for s in [2] +# ] + [ +# triton.Config( +# {'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': 'large'}, +# num_stages=s, num_warps=4) for s in [2] +# ] +# return configs + + +# @triton.autotune( +# configs=get_matmul_batched_autotune_configs(), +# key=['max_num_tokens', 'K', 'N'] +# ) +@triton.jit +def batched_triton_kernel( + a_ptr, # [E, max_num_tokens, K] + b_ptr, # [E, K, N] + c_ptr, # [E, max_num_tokens, N] + expert_num_tokens, # [E] + compute_type: tl.constexpr, + # Dimensions + max_num_tokens: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). 
+ stride_ae: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_be: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_ce: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_ase: tl.constexpr, + stride_asm: tl.constexpr, + stride_ask: tl.constexpr, + stride_bse: tl.constexpr, + stride_bsk: tl.constexpr, + stride_bsn: tl.constexpr, + # Blockwise quantization data + group_n: tl.constexpr, + group_k: tl.constexpr, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_act_token_quant: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # Early exit + return + + # axis 1 is M_blocks * N_blocks + pid_mn = tl.program_id(axis=1) + #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid_mn // num_pid_n + pid_n = pid_mn % num_pid_n + + cta_m_start = pid_m * BLOCK_M + cta_n_start = pid_n * BLOCK_N + if cta_m_start >= e_num_tokens: + # Early exit + return + + cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + cta_n_size = min(BLOCK_N, N - cta_n_start) + + # M = M + a_desc = tl.make_tensor_descriptor(base=a_ptr + expert_id * stride_ae, shape=(e_num_tokens, K), + strides=(stride_am, stride_ak), block_shape=(BLOCK_M, BLOCK_K)) + # b_desc = tl.make_tensor_descriptor(base=b_ptr + expert_id * stride_be, shape=(N, K), strides=(stride_bn, stride_bk), + # block_shape=(BLOCK_N, BLOCK_K)) + b_desc = tl.make_tensor_descriptor(base=b_ptr + expert_id * stride_be, shape=(K, N), strides=(stride_bk, stride_bn), + block_shape=(BLOCK_K, BLOCK_N)) + c_desc = tl.make_tensor_descriptor(base=c_ptr + expert_id * stride_ce, shape=(e_num_tokens, N), + strides=(stride_cm, stride_cn), block_shape=(BLOCK_M, BLOCK_N)) + + # a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am + # b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn + # c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + + # cta_n_start * stride_cn) + + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N + + if use_fp8_w8a8: + a_scale_ptr = a_scale_ptr + expert_id * stride_ase + b_scale_ptr = b_scale_ptr + expert_id * stride_bse + + # block-wise + if group_k > 0 and group_n > 0 or per_act_token_quant: + a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm + + expert_triton_kernel(a_desc, b_desc, c_desc, expert_id, compute_type, cta_m_size, # M + cta_n_size, # N + K, # K + a_scale_ptr, b_scale_ptr, + # Strides + stride_ak, stride_bk, stride_ase, stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn, + # offsets + offs_bn, + # Blockwise quantization data + group_n, group_k, pid_m, pid_n, + # Quantization schemes + use_fp8_w8a8, use_int8_w8a16, per_act_token_quant, + # Kernel config + BLOCK_M, BLOCK_N, BLOCK_K) + + +def invoke_moe_batched_triton_kernel_td(A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, N, K] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + config: dict[str, int], per_act_token_quant: bool, + block_shape: Optional[list[int]] = None): + assert not use_int4_w4a16 + max_num_tokens = A.size(1) + 
K = A.size(2) + N = C.size(2) + + BLOCK_M = config['BLOCK_SIZE_M'] + BLOCK_N = config['BLOCK_SIZE_N'] + BLOCK_K = config['BLOCK_SIZE_K'] + BLOCK_M = 256 + BLOCK_N = 128 + BLOCK_K = 32 + num_warps = 64 + # BLOCK_M = 16 + # BLOCK_N = 16 + # BLOCK_K = 16 + # num_warps = 4 + + grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) + + A_scale = normalize_batched_scales_shape(A_scale, expert_num_tokens.shape[0]) + + if B_scale is not None and B_scale.ndim == 1: + assert B_scale.numel() == expert_num_tokens.shape[0] + B_scale = B_scale.view(-1, 1, 1) + + assert A_scale is None or A_scale.ndim == 3, (f'{0 if A_scale is None else A_scale.shape}') + assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, (f'{0 if B_scale is None else B_scale.shape}') + + if B_scale is not None: + if B_scale.ndim == 1: + stride_bse = 1 + stride_bsk = 0 + stride_bsn = 0 + else: + stride_bse = B_scale.stride(0) + stride_bsk = B_scale.stride(2) + stride_bsn = B_scale.stride(1) + + else: + stride_bse = 0 + stride_bsk = 0 + stride_bsn = 0 + + if A_scale is not None: + stride_ase = A_scale.stride(0) + stride_asm = A_scale.stride(1) + stride_ask = A_scale.stride(2) + else: + stride_ase = 0 + stride_asm = 0 + stride_ask = 0 + + batched_triton_kernel[grid]( + A, + B, + C, + expert_num_tokens, + compute_type, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + A_scale, + B_scale, + B_zp, + # Strides + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + C.stride(2), + stride_ase, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + per_act_token_quant, + # Kernel config + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + num_warps=num_warps, + ) + + +# Benchmark shapes for batched MoE (E: num_experts, M: max_tokens_per_expert, K: hidden_dim, N: intermediate_dim) +BATCHED_MM_X_VALS = [(E, M, K, N) for E in [8, 32] for M in [32, 224, 512] for K in [128, 1024] for N in [128, 1024]] +BATCHED_MM_X_VALS = [ + # (256, 16, 7168, 2048 * 2), + # (256, 16, 7168, 2048), + # (256, 16, 7168, 2048), + # (256, 2, 5000, 2024), + # (256, 16, 2048, 7168), + *[(E, M, K, N) for E in [8, 32] for M in [32, 224, 512] for K in [128, 1024] for N in [128, 1024]] +] + +DEVICE_NAME = torch.xpu.get_device_name() +DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory + + +def is_enough_memory(x_val): + # x_val: (E, M, K, N) + E, M, K, N = x_val + # A: (E, M, K) bfloat16 + # B: (E, K, N) bfloat16 + # C: (E, M, N) float32 + # num_expert_tokens: (E,) int32 + required_memory = E * M * K * 2 + E * K * N * 2 + E * M * N * 4 + E * 4 + enough_memory = required_memory < DEVICE_TOTAL_MEMORY + if not enough_memory: + print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}") + return enough_memory + + +BATCHED_MM_X_VALS = [x_val for x_val in BATCHED_MM_X_VALS if is_enough_memory(x_val)] + + +def get_batched_mm_benchmark( + providers_filter: Optional[list[str]] = None, + dtype: torch.dtype = torch.bfloat16, + use_fp8_w8a8: bool = False, + per_act_token_quant: bool = False, + block_shape: Optional[list[int]] = None, + plot_name: str = 'batched-mm-performance', +): + """ + Returns a Mark object containing a Benchmark object for batched matrix multiplication. 
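+
+    Providers: 'triton' (the batched MoE kernel imported from vLLM),
+    'triton-td' (the tensor-descriptor variant defined in this file) and
+    'pytorch' (the native reference implementation used for correctness checks).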
+ """ + supported_providers = { + 'triton': 'triton', + 'triton-td': 'triton-td', + 'pytorch': 'pytorch', + } + + providers = benchmark_suite.filter_providers(supported_providers, providers_filter) + + # Set up quantization + if use_fp8_w8a8: + act_dtype = torch.bfloat16 + quant_dtype = torch.float8_e4m3fn + else: + act_dtype = dtype + quant_dtype = None + + @benchmark_suite.perf_report( + benchmark_suite.Benchmark( + x_names=['num_experts', 'max_tokens_per_expert', 'K', 'N'], + x_vals=BATCHED_MM_X_VALS, + line_arg='provider', + line_vals=list(providers.keys()), + line_names=list(providers.values()), + styles=[('green', '-'), ('blue', '--'), ('red', ':')], + ylabel=['GB/s', 'TFlops'], + plot_name=plot_name, + args={}, + )) + def benchmark(num_experts, max_tokens_per_expert, K, N, provider): + current_platform.seed_everything(70) + n_warmup = 300 + + # Create random number of expert tokens + num_expert_tokens = torch.randint(low=0, high=max_tokens_per_expert + 1, size=(num_experts, ), device='xpu', + dtype=torch.int32) + out_shape = (num_experts, max_tokens_per_expert, N) + + # Create quantized test activations + A, A_q, A_scale = make_quantized_test_activations( + num_experts, + max_tokens_per_expert, + K, + in_dtype=act_dtype, + quant_dtype=quant_dtype, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + ) + + # Create test weights (only need B matrix for batched MM) + (B, B_q, B_scale, _), _ = make_test_weights( + num_experts, + N // 2, + K, + in_dtype=act_dtype, + quant_dtype=quant_dtype, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + ) + # A_q[:] = 0 + # A_q[:, :, :] = 1 + + # B_q[:] = 0 + # B_q[:, 0, 0] = 1 + + # A_q[:] = 0 + # A_q[:, 0, 0] = 1 + + # B_q[:] = 0 + # B_q[:, 0, 0] = 0 + + quantiles = [0.5, 0.0, 1.0] + + C = torch.zeros(out_shape, device='xpu', dtype=act_dtype) + compute_tl_dtype = {torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32}[C.dtype] + rtol = 6e-2 if act_dtype == torch.bfloat16 else 1e-2 + atol = 6e-2 if act_dtype == torch.bfloat16 else 1e-2 + ref = torch.zeros(out_shape, device='xpu', dtype=act_dtype) + + def torch_fn(): + native_batched_masked_quant_matmul(A_q, B_q, ref, num_expert_tokens, A_scale, B_scale, block_shape, + per_act_token_quant) + return ref + + if provider == 'pytorch': + # PyTorch reference implementation using native_batched_masked_quant_matmul + _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench( + torch_fn, + n_warmup=n_warmup, + n_repeat=10, + quantiles=quantiles, + ) + + elif provider == 'triton' or provider == 'triton-td': + # Triton batched MoE kernel + invoke_kernel = invoke_moe_batched_triton_kernel if provider == 'triton' else invoke_moe_batched_triton_kernel_td + + # invoke_kernel = invoke_moe_batched_triton_kernel_td + def triton_fn(): + invoke_kernel( + A_q, + B_q, + C, + num_expert_tokens, + compute_tl_dtype, + A_scale, + B_scale, + None, + use_fp8_w8a8, + False, + False, + config={'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 16 if dtype.itemsize > 1 else 32}, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + return C + + # Verify correctness against reference + benchmark_suite.assert_close(triton_fn, torch_fn, atol=atol, rtol=rtol, err_msg='triton to torch') + _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench( + triton_fn, + n_warmup=n_warmup, + n_repeat=10, + quantiles=quantiles, + ) + + else: + raise NotImplementedError(f'Unsupported provider {provider}') + + # Calculate performance metrics + # Memory 
bandwidth: A (E*M*K*2) + B (E*K*N*2) + C (E*M*N*4) bytes + # Compute: E * M * N * K * 2 FLOPs (multiply-add) + + def gbps(ms): + total_bytes = num_experts * (max_tokens_per_expert * K * 2 + K * N * 2 + max_tokens_per_expert * N * 4) + return total_bytes * (1e-9) / (ms * 1e-3) + + def tflops(ms): + total_flops = num_experts * max_tokens_per_expert * N * K * 2 + return total_flops * (1e-12) / (ms * 1e-3) + + return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv + + return benchmark + + +if __name__ == '__main__': + # Run batched MM benchmark + print('Running batched MM benchmark...') + _benchmark_mm = get_batched_mm_benchmark( + dtype=torch.bfloat16, + use_fp8_w8a8=(os.getenv('USE_FP8_W8A8', '0') == '1'), + per_act_token_quant=(os.getenv('PER_ACT_TOKEN_QUANT', '0') == '1'), + plot_name='moe-batched-mm-performance', + ) + _benchmark_mm.run(show_plots=False, print_data=True) From 50ea0e74946ad51ecab4c7fba8fb74a3165468a4 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 6 Oct 2025 11:22:19 +0000 Subject: [PATCH 02/13] Added CI --- .github/workflows/third-party-benchmarks.yml | 23 ++++++++++++++++ scripts/test-triton.sh | 29 ++++++++++++++++---- 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index fdb473f391..a3a31e293b 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -87,6 +87,29 @@ jobs: mkdir reports echo "REPORTS=$PWD/reports" >> $GITHUB_ENV + - name: Install benchmarks + if: install-benchmarks + run: | + cd benchmarks + pip install . + pip install intel-pti==0.12.4 + PTI_LIBS_DIR=$(python -c "import sysconfig; print(sysconfig.get_paths()['stdlib']+'/..')") + # the output should contain: `libpti.so`, `libpti_metrics.so.0.12.4` and `libpti_view.so.0.12.4` + ls $PTI_LIBS_DIR + echo "PTI_LIBS_DIR=$PTI_LIBS_DIR" >> $GITHUB_ENV + + - name: Run vllm benchmarks + if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} + run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH + source ./scripts/capture-hw-details.sh + + ./scripts/test-triton.sh --install-vllm --skip-pip-install --skip-pytorch-install + + cd benchmarks + python third_party/vllm/batched_moe_benchmark.py + + - name: Run Liger-Kernel benchmarks if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'liger-kernel')) }} run: | diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index 577e867839..94e3ea675c 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -29,6 +29,7 @@ TEST: --sglang --liger --vllm + --install-vllm OPTION: --unskip @@ -72,6 +73,7 @@ TEST_INDUCTOR=false TEST_SGLANG=false TEST_LIGER=false TEST_VLLM=false +INSTALL_VLLM=false TEST_TRITON_KERNELS=false VENV=false TRITON_TEST_REPORTS=false @@ -203,6 +205,11 @@ while (( $# != 0 )); do TEST_DEFAULT=false shift ;; + --install-vllm) + INSTALL_VLLM=true + TEST_DEFAULT=false + shift + ;; --triton-kernels) TEST_TRITON_KERNELS=true TEST_DEFAULT=false @@ -621,9 +628,9 @@ run_liger_tests() { run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-4} Liger-Kernel/test/ } -run_vllm_tests() { +run_vllm_install() { echo "************************************************" - echo "****** Running VLLM Triton tests ******" + echo "****** Installing VLLM ******" 
echo "************************************************" if ! [ -d "./vllm" ]; then @@ -642,15 +649,24 @@ run_vllm_tests() { git checkout "$(<../benchmarks/third_party/vllm/vllm-kernels-pin.txt)" sed -i '/pytorch\|torch/d' requirements.txt pip install -r requirements.txt - VLLM_TARGET_DEVICE=xpu pip install -e . + VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -e . cd .. - VLLM_TARGET_DEVICE=xpu pip install --no-deps vllm + VLLM_TARGET_DEVICE=xpu pip install --no-deps --no-build-isolation vllm fi - cd vllm pip install pytest pytest-cov pytest-xdist cachetools cbor2 blake3 pybase64 openai_harmony tblib +} + +run_vllm_tests() { + echo "************************************************" + echo "****** Running VLLM Triton tests ******" + echo "************************************************" + + run_vllm_install + + cd vllm run_pytest_command -vvv tests/kernels/moe/test_batched_moe.py tests/kernels/attention/test_triton_unified_attention.py } @@ -738,6 +754,9 @@ test_triton() { if [ "$TEST_VLLM" == true ]; then run_vllm_tests fi + if [ "$INSTALL_VLLM" == true ]; then + run_vllm_install + fi if [ "$TEST_TRITON_KERNELS" == true ]; then run_triton_kernels_tests fi From 2da38775f6823e232b3ccf35ed567923d127caf5 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 6 Oct 2025 11:23:22 +0000 Subject: [PATCH 03/13] Fixed naming --- .github/workflows/third-party-benchmarks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index a3a31e293b..c84f69c08e 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -32,7 +32,7 @@ env: jobs: build: - name: Triton benchmarks + name: Third party benchmarks runs-on: - linux - ${{ inputs.runner_label || 'max1550' }} From 048ea9a49f35f001d4bd77ff7e8de00d211e95ff Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 6 Oct 2025 11:25:12 +0000 Subject: [PATCH 04/13] Fixed typo --- .github/workflows/third-party-benchmarks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index c84f69c08e..8ea3b11e34 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -88,7 +88,7 @@ jobs: echo "REPORTS=$PWD/reports" >> $GITHUB_ENV - name: Install benchmarks - if: install-benchmarks + id: install-benchmarks run: | cd benchmarks pip install . 
From bf66ad2edc4d0e989e49f1b1ed2b2c401a3d6ee1 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 6 Oct 2025 11:53:49 +0000 Subject: [PATCH 05/13] Added processing --- .github/workflows/third-party-benchmarks.yml | 3 +- .../third_party/vllm/batched_moe_benchmark.py | 2 +- .../third_party/vllm/transform_results.py | 100 ++++++++++++++++++ 3 files changed, 103 insertions(+), 2 deletions(-) create mode 100644 benchmarks/third_party/vllm/transform_results.py diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index 8ea3b11e34..08ca333062 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -107,7 +107,8 @@ jobs: ./scripts/test-triton.sh --install-vllm --skip-pip-install --skip-pytorch-install cd benchmarks - python third_party/vllm/batched_moe_benchmark.py + python third_party/vllm/batched_moe_benchmark.py --reports $REPORTS + python third_party/vllm/transform_results.py $REPORTS/moe-gemm.csv $REPORTS/moe-gemm-report.csv --tag $TAG - name: Run Liger-Kernel benchmarks diff --git a/benchmarks/third_party/vllm/batched_moe_benchmark.py b/benchmarks/third_party/vllm/batched_moe_benchmark.py index ce63b02a62..df2a3b29b9 100644 --- a/benchmarks/third_party/vllm/batched_moe_benchmark.py +++ b/benchmarks/third_party/vllm/batched_moe_benchmark.py @@ -487,7 +487,7 @@ def get_batched_mm_benchmark( use_fp8_w8a8: bool = False, per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, - plot_name: str = 'batched-mm-performance', + plot_name: str = 'moe-gemm', ): """ Returns a Mark object containing a Benchmark object for batched matrix multiplication. diff --git a/benchmarks/third_party/vllm/transform_results.py b/benchmarks/third_party/vllm/transform_results.py new file mode 100644 index 0000000000..8765768ef5 --- /dev/null +++ b/benchmarks/third_party/vllm/transform_results.py @@ -0,0 +1,100 @@ +import argparse +import os +import uuid +import json +from datetime import datetime + +import pandas as pd + + +def parse_args(): + parser = argparse.ArgumentParser(description='Parse MoE benchmark CSV') + parser.add_argument('source', help='Path to the MoE benchmark CSV file') + parser.add_argument('target', help='Path to output CSV file') + parser.add_argument('--tag', help='Tag for the benchmark run', default='') + parser.add_argument('--model', help='Model name', default='unknown-model') + parser.add_argument('--max-new-tokens', type=int, help='Maximum new tokens', default=128) + parser.add_argument('--batch-size', type=int, help='Batch size', default=1) + + return parser.parse_args() + + +def parse_moe_csv(csv_file_path, tag): + """Parse the MoE benchmark CSV and extract performance metrics.""" + + df = pd.read_csv(csv_file_path) + + run_uuid = uuid.uuid4().hex + current_datetime = datetime.now().isoformat() + + # Create params for all rows vectorized + df['params'] = df.apply( + lambda row: json.dumps({ + 'num_experts': int(row['num_experts']), + 'max_tokens_per_expert': int(row['max_tokens_per_expert']), + 'K': int(row['K']), + 'N': int(row['N']), + }), axis=1) + + # Define compiler columns + compilers = [('triton', 'triton-TFlops'), ('pytorch', 'pytorch-TFlops'), ('triton-td', 'triton-td-TFlops')] + + # Create list of dataframes for each compiler + dfs = [] + for compiler_name, tflops_col in compilers: + if tflops_col in df.columns: + # Filter out NaN values + valid_rows = df[df[tflops_col].notna()].copy() + if len(valid_rows) > 0: + valid_rows['run_uuid'] = run_uuid + 
valid_rows['ts'] = current_datetime + valid_rows['benchmark_group'] = 'moe-benchmark' + valid_rows['benchmark'] = 'moe-benchmark' + valid_rows['compiler'] = compiler_name + valid_rows['value_name'] = 'tflops' + valid_rows['value'] = valid_rows[tflops_col].astype(float) + valid_rows['tag'] = tag + + # Select only needed columns + result_df = valid_rows[[ + 'run_uuid', 'ts', 'benchmark_group', 'benchmark', 'compiler', 'value_name', 'value', 'params', 'tag' + ]] + dfs.append(result_df) + + # Concatenate all compiler results + df_results = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame() + + host_info = { + n: os.getenv(n.upper(), default='') + for n in [ + 'libigc1_version', + 'level_zero_version', + 'gpu_device', + 'agama_version', + 'torch_version', + 'compiler_version', + 'benchmarking_method', + ] + } + if not host_info['gpu_device']: + raise RuntimeError('Could not find GPU device description, was `capture-hw-details.sh` called?') + + for name, val in host_info.items(): + df_results[name] = val + + print(f'DataFrame shape: {df_results.shape}') + + return df_results + + +def main(): + args = parse_args() + if not os.path.exists(args.source): + raise ValueError(f'Error: CSV file {args.source} not found') + + df_results = parse_moe_csv(args.source, args.tag) + df_results.to_csv(args.target, index=False) + + +if __name__ == '__main__': + main() From e7ce6a7e54064ae39241f3c97f1fc05369c73dbf Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 6 Oct 2025 11:56:33 +0000 Subject: [PATCH 06/13] Changed path --- .github/workflows/third-party-benchmarks.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index 08ca333062..ba04c8dad8 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -106,9 +106,9 @@ jobs: ./scripts/test-triton.sh --install-vllm --skip-pip-install --skip-pytorch-install - cd benchmarks - python third_party/vllm/batched_moe_benchmark.py --reports $REPORTS - python third_party/vllm/transform_results.py $REPORTS/moe-gemm.csv $REPORTS/moe-gemm-report.csv --tag $TAG + cd vllm + python ../benchmarks/third_party/vllm/batched_moe_benchmark.py --reports $REPORTS + python ../benchmarks/third_party/vllm/transform_results.py $REPORTS/moe-gemm.csv $REPORTS/moe-gemm-report.csv --tag $TAG - name: Run Liger-Kernel benchmarks From d83bc3b2c6a779a8823cc0153cac5bdd00493c5a Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 6 Oct 2025 13:44:06 +0000 Subject: [PATCH 07/13] tests copy --- .github/workflows/third-party-benchmarks.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index ba04c8dad8..207dd8d646 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -105,10 +105,11 @@ jobs: source ./scripts/capture-hw-details.sh ./scripts/test-triton.sh --install-vllm --skip-pip-install --skip-pytorch-install + cp -r vllm/tests benchmarks/third_party/vllm/tests - cd vllm - python ../benchmarks/third_party/vllm/batched_moe_benchmark.py --reports $REPORTS - python ../benchmarks/third_party/vllm/transform_results.py $REPORTS/moe-gemm.csv $REPORTS/moe-gemm-report.csv --tag $TAG + cd benchmarks/third_party/vllm + python batched_moe_benchmark.py --reports $REPORTS + python transform_results.py $REPORTS/moe-gemm.csv $REPORTS/moe-gemm-report.csv --tag $TAG 
- name: Run Liger-Kernel benchmarks From 66e65810d99d646a84efc1afade4fad3185e97cc Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 6 Oct 2025 14:19:32 +0000 Subject: [PATCH 08/13] Fixed naming --- .github/workflows/third-party-benchmarks.yml | 2 +- benchmarks/third_party/vllm/batched_moe_benchmark.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index 207dd8d646..d7dc887366 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -109,7 +109,7 @@ jobs: cd benchmarks/third_party/vllm python batched_moe_benchmark.py --reports $REPORTS - python transform_results.py $REPORTS/moe-gemm.csv $REPORTS/moe-gemm-report.csv --tag $TAG + python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG - name: Run Liger-Kernel benchmarks diff --git a/benchmarks/third_party/vllm/batched_moe_benchmark.py b/benchmarks/third_party/vllm/batched_moe_benchmark.py index df2a3b29b9..a5fc160355 100644 --- a/benchmarks/third_party/vllm/batched_moe_benchmark.py +++ b/benchmarks/third_party/vllm/batched_moe_benchmark.py @@ -487,7 +487,7 @@ def get_batched_mm_benchmark( use_fp8_w8a8: bool = False, per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, - plot_name: str = 'moe-gemm', + plot_name: str = 'moe-gemm-performance', ): """ Returns a Mark object containing a Benchmark object for batched matrix multiplication. @@ -644,6 +644,5 @@ def tflops(ms): dtype=torch.bfloat16, use_fp8_w8a8=(os.getenv('USE_FP8_W8A8', '0') == '1'), per_act_token_quant=(os.getenv('PER_ACT_TOKEN_QUANT', '0') == '1'), - plot_name='moe-batched-mm-performance', ) _benchmark_mm.run(show_plots=False, print_data=True) From 4a0df651e785302b272d10c957d7bc5eefcf29f9 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Tue, 7 Oct 2025 15:58:50 +0000 Subject: [PATCH 09/13] Fixed dtype bug --- .../third_party/vllm/batched_moe_benchmark.py | 21 +++---------------- 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/benchmarks/third_party/vllm/batched_moe_benchmark.py b/benchmarks/third_party/vllm/batched_moe_benchmark.py index a5fc160355..d9d9f9d96a 100644 --- a/benchmarks/third_party/vllm/batched_moe_benchmark.py +++ b/benchmarks/third_party/vllm/batched_moe_benchmark.py @@ -250,10 +250,10 @@ def batched_triton_kernel( # moving by 1 element in a particular dimension. E.g. `stride_am` is # how much to increase `a_ptr` by to get the element one row down # (A has M rows). 
- stride_ae: tl.constexpr, + stride_ae: tl.int64, stride_am: tl.constexpr, stride_ak: tl.constexpr, - stride_be: tl.constexpr, + stride_be: tl.int64, stride_bk: tl.constexpr, stride_bn: tl.constexpr, stride_ce: tl.constexpr, @@ -299,11 +299,8 @@ def batched_triton_kernel( cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) cta_n_size = min(BLOCK_N, N - cta_n_start) - # M = M a_desc = tl.make_tensor_descriptor(base=a_ptr + expert_id * stride_ae, shape=(e_num_tokens, K), strides=(stride_am, stride_ak), block_shape=(BLOCK_M, BLOCK_K)) - # b_desc = tl.make_tensor_descriptor(base=b_ptr + expert_id * stride_be, shape=(N, K), strides=(stride_bn, stride_bk), - # block_shape=(BLOCK_N, BLOCK_K)) b_desc = tl.make_tensor_descriptor(base=b_ptr + expert_id * stride_be, shape=(K, N), strides=(stride_bk, stride_bn), block_shape=(BLOCK_K, BLOCK_N)) c_desc = tl.make_tensor_descriptor(base=c_ptr + expert_id * stride_ce, shape=(e_num_tokens, N), @@ -453,7 +450,7 @@ def invoke_moe_batched_triton_kernel_td(A: torch.Tensor, # [E, max_tokens, K] BATCHED_MM_X_VALS = [(E, M, K, N) for E in [8, 32] for M in [32, 224, 512] for K in [128, 1024] for N in [128, 1024]] BATCHED_MM_X_VALS = [ # (256, 16, 7168, 2048 * 2), - # (256, 16, 7168, 2048), + (256, 256, 7168, 2048), # (256, 16, 7168, 2048), # (256, 2, 5000, 2024), # (256, 16, 2048, 7168), @@ -550,18 +547,6 @@ def benchmark(num_experts, max_tokens_per_expert, K, N, provider): block_shape=block_shape, per_act_token_quant=per_act_token_quant, ) - # A_q[:] = 0 - # A_q[:, :, :] = 1 - - # B_q[:] = 0 - # B_q[:, 0, 0] = 1 - - # A_q[:] = 0 - # A_q[:, 0, 0] = 1 - - # B_q[:] = 0 - # B_q[:, 0, 0] = 0 - quantiles = [0.5, 0.0, 1.0] C = torch.zeros(out_shape, device='xpu', dtype=act_dtype) From 493c31b7a6a3361e85f1e34107a82e77fb22504d Mon Sep 17 00:00:00 2001 From: Egor Date: Fri, 10 Oct 2025 17:49:36 +0200 Subject: [PATCH 10/13] Apply suggestions from code review Co-authored-by: Anatoly Myachev --- .github/workflows/third-party-benchmarks.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index d7dc887366..ff7d23d535 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -92,16 +92,10 @@ jobs: run: | cd benchmarks pip install . 
- pip install intel-pti==0.12.4 - PTI_LIBS_DIR=$(python -c "import sysconfig; print(sysconfig.get_paths()['stdlib']+'/..')") - # the output should contain: `libpti.so`, `libpti_metrics.so.0.12.4` and `libpti_view.so.0.12.4` - ls $PTI_LIBS_DIR - echo "PTI_LIBS_DIR=$PTI_LIBS_DIR" >> $GITHUB_ENV - name: Run vllm benchmarks if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} run: | - export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH source ./scripts/capture-hw-details.sh ./scripts/test-triton.sh --install-vllm --skip-pip-install --skip-pytorch-install From 345a38930484f0fda7220a34efc9d34c3b6eec06 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Fri, 10 Oct 2025 15:48:23 +0000 Subject: [PATCH 11/13] Added configs from real models --- .../third_party/vllm/batched_moe_benchmark.py | 157 +++++++++--------- 1 file changed, 82 insertions(+), 75 deletions(-) diff --git a/benchmarks/third_party/vllm/batched_moe_benchmark.py b/benchmarks/third_party/vllm/batched_moe_benchmark.py index d9d9f9d96a..b12d95b2f6 100644 --- a/benchmarks/third_party/vllm/batched_moe_benchmark.py +++ b/benchmarks/third_party/vllm/batched_moe_benchmark.py @@ -10,8 +10,7 @@ batched MoE implementations using vLLM kernels. """ -from typing import Optional -import os +from typing import Optional, List import torch import triton @@ -199,38 +198,34 @@ def expert_triton_kernel( # tl.store(c_ptrs, accumulator, mask=c_mask) -# def get_matmul_batched_autotune_configs(): -# configs = [ -# triton.Config( -# {'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': 'large'}, -# num_stages=s, num_warps=32) for s in [2, 3] -# ] + [ -# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': m}, -# num_stages=s, num_warps=w) for s in [2] for (m, w) in ([('large', 32), ('small', 64)]) -# ] + [ -# triton.Config( -# {'BLOCK_M': 128, 'BLOCK_N': 1024, 'BLOCK_K': 16, 'grf_mode': 'large'}, -# num_stages=s, num_warps=32) for s in [2, 3] -# ] + [ -# triton.Config( -# {'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': 'large'}, -# num_stages=s, num_warps=32) for s in [2] -# ] + [ -# triton.Config( -# {'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': 'large'}, -# num_stages=s, num_warps=32) for s in [2] -# ] + [ -# triton.Config( -# {'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': 'large'}, -# num_stages=s, num_warps=4) for s in [2] -# ] -# return configs - - -# @triton.autotune( -# configs=get_matmul_batched_autotune_configs(), -# key=['max_num_tokens', 'K', 'N'] -# ) +def get_matmul_batched_autotune_configs() -> List[triton.Config]: + configs = [ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32) + for s in [2, 3] + ] + [ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': m}, num_stages=s, num_warps=w) + for s in [2] + for (m, w) in ([('large', 32), ('small', 64)]) + ] + [ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 1024, 'BLOCK_K': 16, 'grf_mode': 'large'}, num_stages=s, num_warps=32) + for s in [2, 3] + ] + [ + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32) + for s in [2] + ] + [ + triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=32) + for s in [2] + ] + [ + triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=4) + for s in [2] + ] + return 
configs + + +@triton.autotune( + configs=get_matmul_batched_autotune_configs(), + key=['max_num_tokens', 'K', 'N'], +) @triton.jit def batched_triton_kernel( a_ptr, # [E, max_num_tokens, K] @@ -356,15 +351,12 @@ def invoke_moe_batched_triton_kernel_td(A: torch.Tensor, # [E, max_tokens, K] BLOCK_M = config['BLOCK_SIZE_M'] BLOCK_N = config['BLOCK_SIZE_N'] - BLOCK_K = config['BLOCK_SIZE_K'] - BLOCK_M = 256 - BLOCK_N = 128 - BLOCK_K = 32 - num_warps = 64 - # BLOCK_M = 16 - # BLOCK_N = 16 - # BLOCK_K = 16 - # num_warps = 4 + # BLOCK_K = config['BLOCK_SIZE_K'] + # Looks like generally good parameters + # BLOCK_M = 256 + # BLOCK_N = 128 + # BLOCK_K = 32 + # num_warps = 64 grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) @@ -439,31 +431,43 @@ def invoke_moe_batched_triton_kernel_td(A: torch.Tensor, # [E, max_tokens, K] use_int8_w8a16, per_act_token_quant, # Kernel config - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_K=BLOCK_K, - num_warps=num_warps, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # BLOCK_K=BLOCK_K, + # num_warps=num_warps, ) -# Benchmark shapes for batched MoE (E: num_experts, M: max_tokens_per_expert, K: hidden_dim, N: intermediate_dim) -BATCHED_MM_X_VALS = [(E, M, K, N) for E in [8, 32] for M in [32, 224, 512] for K in [128, 1024] for N in [128, 1024]] -BATCHED_MM_X_VALS = [ - # (256, 16, 7168, 2048 * 2), - (256, 256, 7168, 2048), - # (256, 16, 7168, 2048), - # (256, 2, 5000, 2024), - # (256, 16, 2048, 7168), - *[(E, M, K, N) for E in [8, 32] for M in [32, 224, 512] for K in [128, 1024] for N in [128, 1024]] -] +# Benchmark shapes for batched MoE +# (E: num_experts, M: max_tokens_per_expert, K: hidden_dim, N: intermediate_dim, fp8, block_quant) +# BATCHED_MM_X_VALS = [(E, M, K, N, False, False) for E in [8, 32] for M in [32, 224, 512] for K in [128, 1024] for N in [128, 1024]] +BATCHED_MM_X_VALS = sum([[(E, M, hidden_size, int_size * 2, fp8, block_quant), + (E, M, int_size, hidden_size, fp8, block_quant)] + for M in [1, 8, 64, 256] + for E, hidden_size, int_size, fp8, block_quant in [ + # deepseek V3, fp8 block quant + # (256, 7168, 2048, True, True), + (256, 7168, 2048, False, False), + # llama4 scout bf16 + (16, 5120, 8192, False, False), + # gpt-oss 20b mxfp4 + (32, 2880, 2880, False, False), + # gpt-oss 120b mxfp4 + (128, 2880, 2880, False, False), + # qwen3-235b-A22B bf16/fp8 + (128, 4096, 1536, False, False), + # qwen3-30b-A3B bf16/fp8 + (128, 2048, 768, False, False), + # qwen3-next-80B bf16 + (512, 2048, 512, False, False), + ]], []) DEVICE_NAME = torch.xpu.get_device_name() DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory def is_enough_memory(x_val): - # x_val: (E, M, K, N) - E, M, K, N = x_val + E, M, K, N, fp8, block_quant = x_val # A: (E, M, K) bfloat16 # B: (E, K, N) bfloat16 # C: (E, M, N) float32 @@ -480,10 +484,7 @@ def is_enough_memory(x_val): def get_batched_mm_benchmark( providers_filter: Optional[list[str]] = None, - dtype: torch.dtype = torch.bfloat16, - use_fp8_w8a8: bool = False, per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, plot_name: str = 'moe-gemm-performance', ): """ @@ -498,16 +499,10 @@ def get_batched_mm_benchmark( providers = benchmark_suite.filter_providers(supported_providers, providers_filter) # Set up quantization - if use_fp8_w8a8: - act_dtype = torch.bfloat16 - quant_dtype = torch.float8_e4m3fn - else: - act_dtype = dtype - quant_dtype = None @benchmark_suite.perf_report( benchmark_suite.Benchmark( - x_names=['num_experts', 
'max_tokens_per_expert', 'K', 'N'], + x_names=['num_experts', 'max_tokens_per_expert', 'K', 'N', 'fp8', 'block_quant'], x_vals=BATCHED_MM_X_VALS, line_arg='provider', line_vals=list(providers.keys()), @@ -517,10 +512,26 @@ def get_batched_mm_benchmark( plot_name=plot_name, args={}, )) - def benchmark(num_experts, max_tokens_per_expert, K, N, provider): + def benchmark(num_experts, max_tokens_per_expert, K, N, fp8, block_quant, provider): current_platform.seed_everything(70) n_warmup = 300 + # print(num_experts, max_tokens_per_expert, K, N, fp8, block_quant, provider, fp8, block_quant) + if fp8: + use_fp8_w8a8 = True + act_dtype = torch.bfloat16 + quant_dtype = torch.float8_e4m3fn + else: + use_fp8_w8a8 = False + act_dtype = torch.bfloat16 + quant_dtype = None + + dtype = torch.bfloat16 + if block_quant: + block_shape = (128, 128) + else: + block_shape = None + # Create random number of expert tokens num_expert_tokens = torch.randint(low=0, high=max_tokens_per_expert + 1, size=(num_experts, ), device='xpu', dtype=torch.int32) @@ -625,9 +636,5 @@ def tflops(ms): if __name__ == '__main__': # Run batched MM benchmark print('Running batched MM benchmark...') - _benchmark_mm = get_batched_mm_benchmark( - dtype=torch.bfloat16, - use_fp8_w8a8=(os.getenv('USE_FP8_W8A8', '0') == '1'), - per_act_token_quant=(os.getenv('PER_ACT_TOKEN_QUANT', '0') == '1'), - ) + _benchmark_mm = get_batched_mm_benchmark() _benchmark_mm.run(show_plots=False, print_data=True) From 2d35f2bd2bb717c94d17ef17106b130dbd7e05b7 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 13 Oct 2025 11:56:16 +0000 Subject: [PATCH 12/13] Fixed script --- scripts/test-triton.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index 94e3ea675c..0bdc5de7ad 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -637,7 +637,7 @@ run_vllm_install() { git clone https://github.com/vllm-project/vllm.git cd vllm git checkout "$(<../benchmarks/third_party/vllm/vllm-pin.txt)" - git apply $TRITON_PROJ/benchmarks/third_party/vllm/vllm-fix.patch + git apply ../benchmarks/third_party/vllm/vllm-fix.patch cd .. fi @@ -652,7 +652,7 @@ run_vllm_install() { VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -e . cd .. - VLLM_TARGET_DEVICE=xpu pip install --no-deps --no-build-isolation vllm + VLLM_TARGET_DEVICE=xpu pip install --no-deps --no-build-isolation -e vllm fi pip install pytest pytest-cov pytest-xdist cachetools cbor2 blake3 pybase64 openai_harmony tblib From 05969440e435d3ecfe6ea27549a1887137fa7b75 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 13 Oct 2025 13:02:47 +0000 Subject: [PATCH 13/13] Enabled fp8 cases --- .github/workflows/third-party-benchmarks.yml | 12 +- .../third_party/vllm/batched_moe_benchmark.py | 105 +++++++++--------- .../third_party/vllm/transform_results.py | 10 +- 3 files changed, 69 insertions(+), 58 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index ff7d23d535..982f76dbc6 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -93,7 +93,7 @@ jobs: cd benchmarks pip install . 
 
-    - name: Run vllm benchmarks
+    - name: Run vllm benchmarks bf16
       if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }}
       run: |
         source ./scripts/capture-hw-details.sh
@@ -103,8 +103,16 @@ jobs:
 
         cd benchmarks/third_party/vllm
         python batched_moe_benchmark.py --reports $REPORTS
-        python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG
+        python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG --benchmark moe-bf16-benchmark
 
+    - name: Run vllm benchmarks fp8
+      if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }}
+      run: |
+        source ./scripts/capture-hw-details.sh
+
+        cd benchmarks/third_party/vllm
+        FP8="1" python batched_moe_benchmark.py --reports $REPORTS
+        python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG --benchmark moe-fp8-benchmark
 
     - name: Run Liger-Kernel benchmarks
       if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'liger-kernel')) }}
       run: |
diff --git a/benchmarks/third_party/vllm/batched_moe_benchmark.py b/benchmarks/third_party/vllm/batched_moe_benchmark.py
index b12d95b2f6..593a9a0626 100644
--- a/benchmarks/third_party/vllm/batched_moe_benchmark.py
+++ b/benchmarks/third_party/vllm/batched_moe_benchmark.py
@@ -10,6 +10,7 @@
 batched MoE implementations using vLLM kernels.
 
 """
+import os
 from typing import Optional, List
 
 import torch
@@ -206,9 +207,6 @@ def get_matmul_batched_autotune_configs() -> List[triton.Config]:
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': m}, num_stages=s, num_warps=w)
         for s in [2]
         for (m, w) in ([('large', 32), ('small', 64)])
-    ] + [
-        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 1024, 'BLOCK_K': 16, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
-        for s in [2, 3]
     ] + [
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
         for s in [2]
@@ -441,26 +439,37 @@ def invoke_moe_batched_triton_kernel_td(A: torch.Tensor,  # [E, max_tokens, K]
 # Benchmark shapes for batched MoE
 # (E: num_experts, M: max_tokens_per_expert, K: hidden_dim, N: intermediate_dim, fp8, block_quant)
 # BATCHED_MM_X_VALS = [(E, M, K, N, False, False) for E in [8, 32] for M in [32, 224, 512] for K in [128, 1024] for N in [128, 1024]]
-BATCHED_MM_X_VALS = sum([[(E, M, hidden_size, int_size * 2, fp8, block_quant),
-                          (E, M, int_size, hidden_size, fp8, block_quant)]
-                         for M in [1, 8, 64, 256]
-                         for E, hidden_size, int_size, fp8, block_quant in [
-                             # deepseek V3, fp8 block quant
-                             # (256, 7168, 2048, True, True),
-                             (256, 7168, 2048, False, False),
-                             # llama4 scout bf16
-                             (16, 5120, 8192, False, False),
-                             # gpt-oss 20b mxfp4
-                             (32, 2880, 2880, False, False),
-                             # gpt-oss 120b mxfp4
-                             (128, 2880, 2880, False, False),
-                             # qwen3-235b-A22B bf16/fp8
-                             (128, 4096, 1536, False, False),
-                             # qwen3-30b-A3B bf16/fp8
-                             (128, 2048, 768, False, False),
-                             # qwen3-next-80B bf16
-                             (512, 2048, 512, False, False),
-                         ]], [])
+# Each pair represents the transformation into the SwiGLU activation and then the output transformation
+MM_CONFIGS_BF16 = sum([[(E, M, hidden_size, int_size * 2, fp8, block_quant),  # input -> swiglu input
+                        (E, M, int_size, hidden_size, fp8, block_quant)]  # swiglu output -> final output
+                       for M in [1, 8, 64, 256]
+                       for E, hidden_size, int_size, fp8, block_quant in [
+                           # llama4 scout bf16
+                           (16, 5120, 8192, False, False),
+                           # gpt-oss 20b mxfp4
+                           (32, 2880, 2880, False, False),
+                           # gpt-oss 120b mxfp4
+                           (128, 2880, 2880, False, False),
+                           # qwen3-235b-A22B bf16/fp8
+                           (128, 4096, 1536, False, False),
+                           # qwen3-30b-A3B bf16/fp8
+                           (128, 2048, 768, False, False),
+                           # qwen3-next-80B bf16
+                           (512, 2048, 512, False, False),
+                       ]], [])
+
+MM_CONFIGS_FP8 = sum([[(E, M, hidden_size, int_size * 2, fp8, block_quant),
+                       (E, M, int_size, hidden_size, fp8, block_quant)]
+                      for M in [1, 8, 64, 256]
+                      for E, hidden_size, int_size, fp8, block_quant in [
+                          # deepseek V3, fp8 block quant
+                          # 3.5 GBs of weights!
+                          (256, 7168, 2048, True, True),
+                          # qwen3-235b-A22B bf16/fp8
+                          (128, 4096, 1536, False, True),
+                          # qwen3-30b-A3B bf16/fp8
+                          (128, 2048, 768, False, True),
+                      ]], [])
 
 DEVICE_NAME = torch.xpu.get_device_name()
 DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory
@@ -468,23 +477,25 @@ def invoke_moe_batched_triton_kernel_td(A: torch.Tensor,  # [E, max_tokens, K]
 
 def is_enough_memory(x_val):
     E, M, K, N, fp8, block_quant = x_val
-    # A: (E, M, K) bfloat16
-    # B: (E, K, N) bfloat16
-    # C: (E, M, N) float32
+    # A: (E, M, K) bfloat16 or fp8
+    # B: (E, K, N) bfloat16 or fp8
+    # C: (E, M, N) bfloat16
     # num_expert_tokens: (E,) int32
-    required_memory = E * M * K * 2 + E * K * N * 2 + E * M * N * 4 + E * 4
+    n_bytes = 1 if fp8 else 2
+    required_memory = E * M * K * n_bytes + E * K * N * n_bytes + E * M * N * 2 + E * 4
    enough_memory = required_memory < DEVICE_TOTAL_MEMORY
     if not enough_memory:
         print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}")
     return enough_memory
 
 
-BATCHED_MM_X_VALS = [x_val for x_val in BATCHED_MM_X_VALS if is_enough_memory(x_val)]
+MM_CONFIGS_BF16 = [x_val for x_val in MM_CONFIGS_BF16 if is_enough_memory(x_val)]
 
 
 def get_batched_mm_benchmark(
     providers_filter: Optional[list[str]] = None,
     per_act_token_quant: bool = False,
+    fp8: bool = False,
     plot_name: str = 'moe-gemm-performance',
 ):
     """
@@ -496,42 +507,34 @@ def get_batched_mm_benchmark(
         'triton-td': 'triton-td',
         'pytorch': 'pytorch',
     }
+    if fp8:
+        # PyTorch is very slow in the fp8 case: for the (8, 64, 1024, 2048) shape it reaches ~0.15 TFlops vs ~1.5 for Triton
+        del supported_providers['pytorch']
 
     providers = benchmark_suite.filter_providers(supported_providers, providers_filter)
-
-    # Set up quantization
+    configs = MM_CONFIGS_FP8 if fp8 else MM_CONFIGS_BF16
 
     @benchmark_suite.perf_report(
         benchmark_suite.Benchmark(
             x_names=['num_experts', 'max_tokens_per_expert', 'K', 'N', 'fp8', 'block_quant'],
-            x_vals=BATCHED_MM_X_VALS,
+            x_vals=configs,
             line_arg='provider',
             line_vals=list(providers.keys()),
             line_names=list(providers.values()),
             styles=[('green', '-'), ('blue', '--'), ('red', ':')],
             ylabel=['GB/s', 'TFlops'],
             plot_name=plot_name,
             args={},
         ))
     def benchmark(num_experts, max_tokens_per_expert, K, N, fp8, block_quant, provider):
         current_platform.seed_everything(70)
-        n_warmup = 300
+        n_warmup = 600
 
-        # print(num_experts, max_tokens_per_expert, K, N, fp8, block_quant, provider, fp8, block_quant)
-        if fp8:
-            use_fp8_w8a8 = True
-            act_dtype = torch.bfloat16
-            quant_dtype = torch.float8_e4m3fn
-        else:
-            use_fp8_w8a8 = False
-            act_dtype = torch.bfloat16
-            quant_dtype = None
+        act_dtype = torch.bfloat16
+        use_fp8_w8a8 = fp8
+        quant_dtype = torch.float8_e4m3fn if fp8 else None
 
-        dtype = torch.bfloat16
-        if block_quant:
-            block_shape = (128, 128)
-        else:
-            block_shape = None
+        block_shape = (128, 128) if block_quant else None
 
         # Create random number of expert tokens
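+        # Drawing counts in [0, max_tokens_per_expert] exercises the kernels'
+        # early-exit paths for empty and partially filled experts.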
         num_expert_tokens = torch.randint(low=0, high=max_tokens_per_expert + 1, size=(num_experts, ), device='xpu',
@@ -584,7 +587,6 @@ def torch_fn():
         # Triton batched MoE kernel
         invoke_kernel = invoke_moe_batched_triton_kernel if provider == 'triton' else invoke_moe_batched_triton_kernel_td
 
-        # invoke_kernel = invoke_moe_batched_triton_kernel_td
         def triton_fn():
             invoke_kernel(
                 A_q,
@@ -598,7 +600,8 @@ def triton_fn():
                 use_fp8_w8a8,
                 False,
                 False,
-                config={'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 16 if dtype.itemsize > 1 else 32},
+                config={'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32 if fp8 else 16},
+                # config={'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32 if dtype.itemsize > 1 else 32},
                 per_act_token_quant=per_act_token_quant,
                 block_shape=block_shape,
             )
@@ -621,7 +624,9 @@ def triton_fn():
     # Compute: E * M * N * K * 2 FLOPs (multiply-add)
 
     def gbps(ms):
-        total_bytes = num_experts * (max_tokens_per_expert * K * 2 + K * N * 2 + max_tokens_per_expert * N * 4)
+        n_bytes = 1 if fp8 else 2
+        total_bytes = num_experts * (max_tokens_per_expert * K * n_bytes + K * N * n_bytes +
+                                     max_tokens_per_expert * N * 2)
         return total_bytes * (1e-9) / (ms * 1e-3)
 
     def tflops(ms):
@@ -636,5 +641,5 @@ def tflops(ms):
 if __name__ == '__main__':
     # Run batched MM benchmark
     print('Running batched MM benchmark...')
-    _benchmark_mm = get_batched_mm_benchmark()
+    _benchmark_mm = get_batched_mm_benchmark(fp8=(os.getenv('FP8', '0') == '1'))
     _benchmark_mm.run(show_plots=False, print_data=True)
diff --git a/benchmarks/third_party/vllm/transform_results.py b/benchmarks/third_party/vllm/transform_results.py
index 8765768ef5..66ada1157c 100644
--- a/benchmarks/third_party/vllm/transform_results.py
+++ b/benchmarks/third_party/vllm/transform_results.py
@@ -12,14 +12,12 @@ def parse_args():
     parser.add_argument('source', help='Path to the MoE benchmark CSV file')
     parser.add_argument('target', help='Path to output CSV file')
     parser.add_argument('--tag', help='Tag for the benchmark run', default='')
-    parser.add_argument('--model', help='Model name', default='unknown-model')
-    parser.add_argument('--max-new-tokens', type=int, help='Maximum new tokens', default=128)
-    parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
+    parser.add_argument('--benchmark', help='Benchmark name to record in the report', default='')
 
     return parser.parse_args()
 
 
-def parse_moe_csv(csv_file_path, tag):
+def parse_moe_csv(csv_file_path, tag, benchmark):
     """Parse the MoE benchmark CSV and extract performance metrics."""
 
     df = pd.read_csv(csv_file_path)
@@ -49,7 +47,7 @@ def parse_moe_csv(csv_file_path, tag):
                 valid_rows['run_uuid'] = run_uuid
                 valid_rows['ts'] = current_datetime
                 valid_rows['benchmark_group'] = 'moe-benchmark'
-                valid_rows['benchmark'] = 'moe-benchmark'
+                valid_rows['benchmark'] = benchmark
                 valid_rows['compiler'] = compiler_name
                 valid_rows['value_name'] = 'tflops'
                 valid_rows['value'] = valid_rows[tflops_col].astype(float)
@@ -92,7 +90,7 @@ def main():
     args = parse_args()
     if not os.path.exists(args.source):
         raise ValueError(f'Error: CSV file {args.source} not found')
 
-    df_results = parse_moe_csv(args.source, args.tag)
+    df_results = parse_moe_csv(args.source, args.tag, args.benchmark)
     df_results.to_csv(args.target, index=False)
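
Local reproduction note (a sketch, not part of the patches): assuming vLLM was
installed via ./scripts/test-triton.sh --install-vllm, capture-hw-details.sh has
been sourced, and REPORTS/TAG are set as in the CI workflow above, the two
benchmark steps reduce to:

    cd benchmarks/third_party/vllm
    python batched_moe_benchmark.py --reports "$REPORTS"
    python transform_results.py "$REPORTS/moe-gemm-performance.csv" \
        "$REPORTS/moe-gemm-report.csv" --tag "$TAG" --benchmark moe-bf16-benchmark
    FP8=1 python batched_moe_benchmark.py --reports "$REPORTS"
    python transform_results.py "$REPORTS/moe-gemm-performance.csv" \
        "$REPORTS/moe-gemm-report.csv" --tag "$TAG" --benchmark moe-fp8-benchmark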