From d235c12bc2f025e125fe98bd70165f9aa4a5e7e3 Mon Sep 17 00:00:00 2001
From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
Date: Thu, 19 Mar 2026 06:04:09 +0000
Subject: [PATCH 1/4] perf: Port sparse MLA kernel selection heuristics from trtllm-gen MR #885

Port heuristic improvements from trtllm-gen MR #885 to FlashInfer's
trtllm-gen FMHA kernel selection (SM100/SM103 sparse MLA decode):

1. New `selectSparseMlaGenerationKernel()` separates sparse MLA selection
   from the non-sparse path with tuned heuristics:
   - numHeadsQ <= 32 (SwapsMmaAb): batch-aware tileSizeQ halving at
     batch=1 (2x more head-splitting CTAs when GPU is under-utilized);
     threshold batchSize * maxNumCtasPerSeqKv <= MP/8.
   - numHeadsQ >= 64 (KeepsMmaAb): batch-aware 1CTA/2CTA selection for
     numHeadsQ=128; threshold batchSize * numCtasPerToken * 8 > MP.
   - numHeadsQ=64 now uses KeepsMmaAb (tileSizeQ=64) instead of
     SwapsMmaAb (tileSizeQ=16), yielding 1.35-2.74x speedup across all
     batch sizes.

2. CgaSmemReduction guard scoped to MLA kernels: only suppress for
   tileSizeQ >= 32 (headDimQk=576 exceeds smem limit); non-MLA kernels
   are unaffected.

3. Fix kernel re-selection deadlocks by guarding mMultiCtasKvMode
   assignments with !mSelectNewKernel:
   - SwapsMmaAb path: preserve CgaSmemReduction upgrade on re-selection.
   - KeepsMmaAb path: preserve Disabled mode when numCtasPerSeqKv==1
     (small topK), avoiding infinite loop.

Also add benchmark script: benchmarks/bench_sparse_mla.py sweeping
batch=[1,32,128,512], seqlenKv=[1k-32k], numHeadsQ=[16,32,64,128],
dtype=[bf16,e4m3] for DeepSeek-V3 sparse MLA config.

Update TRTLLM_GEN_FMHA cubin artifact hash to e7afc4134b (new kernels
required for numHeadsQ=64 1CTA KeepsMmaAb sparse variants).

Validated: 4512/4512 test_trtllm_gen_mla.py tests pass (3168 skipped).
--- benchmarks/bench_sparse_mla.py | 230 ++++++++++++++++++ flashinfer/artifacts.py | 4 +- .../flashinfer/trtllm/fmha/fmhaKernels.cuh | 131 +++++++--- include/flashinfer/trtllm/fmha/kernelParams.h | 4 + 4 files changed, 335 insertions(+), 34 deletions(-) create mode 100644 benchmarks/bench_sparse_mla.py diff --git a/benchmarks/bench_sparse_mla.py b/benchmarks/bench_sparse_mla.py new file mode 100644 index 0000000000..c3d44785d9 --- /dev/null +++ b/benchmarks/bench_sparse_mla.py @@ -0,0 +1,230 @@ +""" +Benchmark for sparse MLA (trtllm-gen backend) across a grid of: + batch_size : 1, 32, 128, 512 + seqlen_kv : 1024, 2048, 4096, 8192, 32768 + num_heads_q : 16, 32, 64, 128 + dtype : bf16 (query+kv+out), e4m3 (query+kv, bf16 out) + +DeepSeek-V3 sparse MLA config: + kv_lora_rank = 512, qk_rope_head_dim = 64, qk_nope_head_dim = 512 + sparse_mla_top_k = min(2048, seqlen_kv) + page_size = 32 +""" + +import csv +import math +import os +import random +import sys +from datetime import datetime + +import torch + +import flashinfer +from flashinfer.testing.utils import bench_gpu_time +from flashinfer.utils import get_compute_capability + +# --------------------------------------------------------------------------- +# DeepSeek-V3 MLA dims +# --------------------------------------------------------------------------- +KV_LORA_RANK = 512 +QK_ROPE_HEAD_DIM = 64 +QK_NOPE_HEAD_DIM = KV_LORA_RANK # = 512 +QK_HEAD_DIM = KV_LORA_RANK + QK_ROPE_HEAD_DIM # = 576 +PAGE_SIZE = 32 +SPARSE_TOP_K_MAX = 2048 + +# --------------------------------------------------------------------------- +# Sweep parameters +# --------------------------------------------------------------------------- +BATCH_SIZES = [1, 32, 128, 512] +SEQLEN_KVS = [1024, 2048, 4096, 8192, 32768] +NUM_HEADS_Q_LIST = [16, 32, 64, 128] +DTYPES = [ + ("bf16", torch.bfloat16, torch.bfloat16), # (tag, q_dtype, kv_dtype) + ("e4m3", torch.float8_e4m3fn, torch.float8_e4m3fn), +] + +NUM_ITERS = 30 +DRY_RUN_ITERS = 5 + + +# 
--------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def generate_sparse_indices(batch_size, q_len, seq_lens, topk, page_size, block_tables, device): + """Returns indices_in_kvcache: [batch_size, q_len, topk] pointing into the flat KV pool.""" + block_tables_cpu = block_tables.cpu() + seq_lens_cpu = seq_lens.cpu() + + indices_in_kvcache = torch.empty(batch_size, q_len, topk, dtype=torch.int32, device="cpu") + + for i in range(batch_size): + cur_seq_len = int(seq_lens_cpu[i].item()) + actual_topk = min(topk, cur_seq_len) + for j in range(q_len): + cur_abs = torch.arange(0, actual_topk, device="cpu") + cur_blocked = ( + block_tables_cpu[i, cur_abs // page_size] * page_size + (cur_abs % page_size) + ) + if actual_topk < topk: + pad = torch.full((topk - actual_topk,), -1, dtype=torch.int32, device="cpu") + cur_blocked = torch.cat([cur_blocked, pad]) + indices_in_kvcache[i, j, :] = cur_blocked + + return indices_in_kvcache.to(device) + + +def setup_inputs(batch_size, seqlen_kv, num_heads_q, q_dtype, kv_dtype, device): + """Create all tensors needed for a sparse MLA decode call.""" + topk = min(SPARSE_TOP_K_MAX, seqlen_kv) + q_len = 1 # decode phase + + # Query: [B, q_len, H, QK_HEAD_DIM] + query = torch.randn(batch_size, q_len, num_heads_q, QK_HEAD_DIM, device=device) + query.clamp_(-1.0, 1.0) + query = query.to(q_dtype) + + # KV cache pool + seq_lens = torch.full((batch_size,), seqlen_kv, dtype=torch.int32, device=device) + blocks_per_seq = (seq_lens + PAGE_SIZE - 1) // PAGE_SIZE + total_blocks = int(blocks_per_seq.sum().item()) + + all_block_ids = torch.randperm(total_blocks, device=device) + max_blocks = int(blocks_per_seq.max().item()) + block_tables = torch.zeros(batch_size, max_blocks, dtype=torch.int32, device=device) + bid = 0 + for i in range(batch_size): + nb = int(blocks_per_seq[i].item()) + block_tables[i, :nb] = all_block_ids[bid : bid + nb] + 
bid += nb + + kv_cache = torch.randn(total_blocks, PAGE_SIZE, QK_HEAD_DIM, device=device) + kv_cache.clamp_(-1.0, 1.0) + kv_cache = kv_cache.to(kv_dtype) + + # Sparse indices: [B, q_len, topk] + indices_in_kvcache = generate_sparse_indices( + batch_size, q_len, seq_lens, topk, PAGE_SIZE, block_tables, device + ) + + # Workspace (zero-initialised, as required) + workspace = torch.zeros(256 * 1024 * 1024, dtype=torch.int8, device=device) + + bmm1_scale = 1.0 / math.sqrt(QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM) + bmm2_scale = 1.0 + + return dict( + query=query, + kv_cache=kv_cache.unsqueeze(1), # [blocks, 1, page_size, head_dim] + workspace_buffer=workspace, + qk_nope_head_dim=QK_NOPE_HEAD_DIM, + kv_lora_rank=KV_LORA_RANK, + qk_rope_head_dim=QK_ROPE_HEAD_DIM, + block_tables=indices_in_kvcache, + seq_lens=seq_lens, + max_seq_len=seqlen_kv, + sparse_mla_top_k=topk, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + backend="trtllm-gen", + ) + + +def run_one(batch_size, seqlen_kv, num_heads_q, dtype_tag, q_dtype, kv_dtype, device): + topk = min(SPARSE_TOP_K_MAX, seqlen_kv) + kwargs = setup_inputs(batch_size, seqlen_kv, num_heads_q, q_dtype, kv_dtype, device) + + # Warmup + measure + measurements = bench_gpu_time( + flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla, + dry_run_iters=DRY_RUN_ITERS, + repeat_iters=NUM_ITERS, + enable_cupti=True, + use_cuda_graph=True, + input_kwargs=kwargs, + ) + median_ms = float(torch.tensor(measurements).median().item()) + std_ms = float(torch.tensor(measurements).float().std().item()) + + # Memory-bandwidth estimate: kv bytes accessed + def elem_bytes(dtype): + return torch.empty(1, dtype=dtype).element_size() + + kv_bytes = batch_size * topk * QK_HEAD_DIM * elem_bytes(kv_dtype) + q_bytes = batch_size * num_heads_q * QK_HEAD_DIM * elem_bytes(q_dtype) + o_bytes = batch_size * num_heads_q * KV_LORA_RANK * 2 # bf16 output always + total_bytes = kv_bytes + q_bytes + o_bytes + bw_tbs = total_bytes / median_ms / 1e9 + + print( + 
f"bs={batch_size:4d} seqkv={seqlen_kv:6d} H={num_heads_q:3d} " + f"dtype={dtype_tag} topk={topk:5d} " + f"median={median_ms:.3f}ms std={std_ms:.3f}ms bw={bw_tbs:.2f}TB/s" + ) + return dict( + batch_size=batch_size, + seqlen_kv=seqlen_kv, + num_heads_q=num_heads_q, + dtype=dtype_tag, + sparse_top_k=topk, + median_ms=median_ms, + std_ms=std_ms, + bw_tbs=bw_tbs, + ) + + +def main(): + device = torch.device("cuda:0") + cc = get_compute_capability(device) + if cc[0] != 10: + print(f"ERROR: trtllm-gen sparse MLA requires SM100/SM103, got SM{cc[0]}{cc[1]}") + sys.exit(1) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + csv_path = f"bench_sparse_mla_{timestamp}.csv" + + fieldnames = [ + "batch_size", "seqlen_kv", "num_heads_q", "dtype", + "sparse_top_k", "median_ms", "std_ms", "bw_tbs", + ] + + results = [] + total = len(BATCH_SIZES) * len(SEQLEN_KVS) * len(NUM_HEADS_Q_LIST) * len(DTYPES) + done = 0 + + print(f"Running {total} configurations. Results -> {csv_path}\n") + print(f"{'bs':>5} {'seqkv':>7} {'H':>4} {'dtype':>5} {'topk':>6} " + f"{'median_ms':>10} {'std_ms':>8} {'bw_TB/s':>9}") + print("-" * 65) + + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + + for dtype_tag, q_dtype, kv_dtype in DTYPES: + for num_heads_q in NUM_HEADS_Q_LIST: + for seqlen_kv in SEQLEN_KVS: + for batch_size in BATCH_SIZES: + try: + row = run_one( + batch_size, seqlen_kv, num_heads_q, + dtype_tag, q_dtype, kv_dtype, device, + ) + results.append(row) + writer.writerow(row) + f.flush() + except Exception as e: + print( + f" SKIP bs={batch_size} seqkv={seqlen_kv} " + f"H={num_heads_q} dtype={dtype_tag}: {e}" + ) + done += 1 + + print(f"\nDone. {len(results)}/{total} succeeded. 
CSV saved to {csv_path}") + + +if __name__ == "__main__": + torch.manual_seed(42) + random.seed(42) + main() diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 99a51039bc..0b517ab023 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -135,7 +135,7 @@ class ArtifactPath: When compiling new cubins for backend directories, update the corresponding path. """ - TRTLLM_GEN_FMHA: str = "f1ed60e5666a7620683a8c34a41c850a25029b35/fmha/trtllm-gen/" + TRTLLM_GEN_FMHA: str = "e7afc4134bb53eaab63fb85163d5943fb190621c/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "b55211623be7f5697c5262ffd8361fc06c147bc9/batched_gemm-b3c1646-c111d7c/" ) @@ -155,7 +155,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "10a54e8c3175099481aed2739ae30fa0f782368c40f9ad1b423ed8353315d65b" + "8e99084003b6bbc07a9ea61822c32de649254594065cbc52ebb020e2b4ef1752" ) TRTLLM_GEN_BMM: str = ( "0af823880730c4f0b3832d2208fab035946694b83444410b9309db5613d60195" diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index aa0e640ede..b2bca5f4d1 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -432,10 +432,13 @@ class TllmGenFmhaKernel { // Enable the CgaSmemReduction if the numCtasPerSeqKv <= 16 as the maximum cluster dimension // is 16. Only the swapsMmaAbForGeneration kernel supports the CgaSmemReduction for now. + // CgaSmemReduction exceeds the shared memory limit for MLA decode with tileSizeQ >= 32 + // (headDimQk=576 requires more smem than the device allows for that tile size). 
if (!isDsv3MinLatencyMode && numCtasPerSeqKv > 1 && numCtasPerSeqKv <= 16 && isSwapsMmaAbForGenerationKernel(selectKernelParams.mKernelType) && isGmemReduction(selectKernelParams.mMultiCtasKvMode) && - !selectKernelParams.mForceGmemReduction) { + !selectKernelParams.mForceGmemReduction && + (!isMlaGenKernel(params) || selectKernelParams.mTileSizeQ < 32)) { selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::CgaSmemReduction; // Need to select a different kernel. selectKernelParams.mSelectNewKernel = true; @@ -514,49 +517,113 @@ class TllmGenFmhaKernel { return seqLenPerCtaKv <= 1024 && numCtas <= params.mMultiProcessorCount; } - // Select the MLA generation kernel. - void selectMlaGenerationKernel(RunnerParams const& params, - SelectKernelParams& selectKernelParams) const { - // We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the - // following conditions are met: - // 1. The number of headsQPerKv is <= 32. - // 2. The number of headsQPerKv is < 128 for sparseMla. - // 3. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned later) - // and - // the numCtas (after splitting the heads across multiple CTAs) <= - // params.mMultiProcessorCount. - // The sparseMla kernel will always use the 2CTA high-throughput kernel. + // Select the sparse MLA generation kernel. + // Heuristics benchmarked on B200 (SM=148, sparseMlaTopK=2048). + void selectSparseMlaGenerationKernel(RunnerParams const& params, + SelectKernelParams& selectKernelParams) const { + // numHeadsQ <= 32 : SwapsMmaAbForGeneration + // tileSizeQ = numHeadsQPerKv/2 at batch=1 (GPU under-utilized with full tile; halving creates + // 2x more head-splitting CTAs), or numHeadsQPerKv at batch>=2. + // Threshold: batchSize * maxNumCtasPerSeqKv <= MP/8 (crossover at batch=1->2 on B200). + // Benchmarks (seqLen=8192, topK=2048): half tileSizeQ wins by 2-6% at batch=1; + // full tileSizeQ wins by 2-11% at batch>=2. 
+ // numHeadsQ >= 64 : KeepsMmaAbForGeneration, tileSizeQ = 64 + // numHeadsQ=128 at large batch : 2CTA (clusterDimX=2, headDimPerCtaV=256) + // otherwise : 1CTA, headDimPerCtaV fine-tuned later + // Note: at small batch e4m3 prefers SwapsMmaAb tileSizeQ=32 (+10%), but fp16 prefers + // KeepsMmaAb tileSizeQ=64 (+19% at numHeadsQ=128). We keep KeepsMmaAb for numHeadsQ>=64 + // to avoid penalizing fp16. - // The kernel type. FmhaKernelType& kernelType = selectKernelParams.mKernelType; - // The tile size for Q. int& tileSizeQ = selectKernelParams.mTileSizeQ; - // Check the conditions. - if (params.mNumHeadsQPerKv <= 32 || (params.mSparseMla && params.mNumHeadsQPerKv < 128) || - useSwapsMmaAbMlaGenKernel(params)) { + if (params.mNumHeadsQPerKv <= 32) { kernelType = FmhaKernelType::SwapsMmaAbForGeneration; - // Currently, only tileSizeQ = 8 or 16 are supported. - tileSizeQ = params.mNumHeadsQPerKv <= 8 ? 8 : 16; + selectKernelParams.mTileSizeKv = 128; + // Only set GmemReduction on the first selection pass. + // computeCtaAndClusterConfig may upgrade it to CgaSmemReduction and set mSelectNewKernel=true; + // preserving the updated mode on re-selection avoids an infinite loop. + if (!selectKernelParams.mSelectNewKernel) { + selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReduction; + } + // The base tileSizeQ is numHeadsQPerKv (one CTA covers all Q heads per token). At batch=1 + // the GPU is under-utilized, so we halve tileSizeQ to create 2x more head-splitting CTAs. + // Threshold: batchSize * maxNumCtasPerSeqKv <= MP/8. + // effectiveSeqLenKv = min(seqLen, topK) = 2048 -> maxNumCtasPerSeqKv = 16. + // Condition: batchSize * 16 <= MP/8 -> batchSize <= 1 (crossover at batch=1->2). + // Only halve when half tileSizeQ >= 8 (no valid SwapsMmaAb kernel below tileSizeQ=8). 
+ int const fullTileSizeQ = params.mNumHeadsQPerKv; + int const halfTileSizeQ = fullTileSizeQ / 2; + int const effectiveSeqLenKv = std::min(params.mMaxSeqLenKv, params.mSparseMlaTopK); + int const maxNumCtasPerSeqKv = + flashinfer::ceil_div(effectiveSeqLenKv, selectKernelParams.mTileSizeKv); + bool const useHalfTileSizeQ = + halfTileSizeQ >= 8 && + params.mBatchSize * maxNumCtasPerSeqKv <= params.mMultiProcessorCount / 8; + tileSizeQ = useHalfTileSizeQ ? halfTileSizeQ : fullTileSizeQ; } else { - // Otherwise, we use the high-throughput kernel. + // numHeadsQ >= 64: use KeepsMmaAbForGeneration. kernelType = FmhaKernelType::KeepsMmaAbForGeneration; - // Use the tileSizeQ = 64 for MLA high-throughput generation kernels. tileSizeQ = 64; - // Always use the separate reduction kernel. - if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) { + selectKernelParams.mTileSizeKv = 128; + // Only set GmemReductionWithSeparateKernel on the first selection pass. + // computeCtaAndClusterConfig may disable it (numCtasPerSeqKv==1) and set mSelectNewKernel=true; + // preserving the updated Disabled mode on re-selection avoids an infinite loop. + if (!selectKernelParams.mSelectNewKernel) { selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel; } - // The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128. - FLASHINFER_CHECK( - !params.mSparseMla || params.mNumHeadsQPerKv == 128, - "The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128, got %d", - params.mNumHeadsQPerKv); - // The 2CTA keepsMmaAbForGeneration kernel is used when the numHeadsQPerKv is 128. - if (params.mNumHeadsQPerKv == 128) { + // For numHeadsQ=128, use 2CTA when there are enough CTAs to amortize 2CTA overhead. + // numCtasPerToken = numHeadsQPerKv / tileSizeQ (number of CTAs per token per batch item). 
+ // Benchmarks (fp16/e4m3, sparseMlaTopK=2048): + // batch=1 : 1CTA wins by ~20%; batch=8 : 1CTA wins by 3-8% + // batch=16 : 2CTA wins by 8-16%; batch=32+: 2CTA wins by 12-20% + // Threshold: batchSize * numCtasPerToken * 8 > MP -> crossover at batch ~ MP/16 ~ 9. + int const numCtasPerToken = params.mNumHeadsQPerKv / 64; + bool const use2Cta = params.mNumHeadsQPerKv == 128 && + params.mBatchSize * numCtasPerToken * 8 > params.mMultiProcessorCount; + if (use2Cta) { selectKernelParams.mUses2CtaMma = true; - // Each Cta only handles 256 headDimV. selectKernelParams.mHeadDimPerCtaV = 256; + } else if (!selectKernelParams.mSelectNewKernel) { + // Only set headDimPerCtaV on the first selection pass. + // computeCtaAndClusterConfig may reduce it and set mSelectNewKernel=true; + // preserving the updated value on re-selection avoids an infinite loop. + selectKernelParams.mHeadDimPerCtaV = 512; + } + } + } + + // Select the MLA generation kernel. + void selectMlaGenerationKernel(RunnerParams const& params, + SelectKernelParams& selectKernelParams) const { + // The kernel type. + FmhaKernelType& kernelType = selectKernelParams.mKernelType; + // The tile size for Q. + int& tileSizeQ = selectKernelParams.mTileSizeQ; + + if (params.mSparseMla) { + selectSparseMlaGenerationKernel(params, selectKernelParams); + } else { + // Non-sparse MLA: use SwapsMmaAb when numHeadsQPerKv <= 32 or seqLenPerCtaKv is small. + bool const useSwapsMmaAb = + params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params); + + if (useSwapsMmaAb) { + kernelType = FmhaKernelType::SwapsMmaAbForGeneration; + // Non-sparse MLA (legacy): tileSizeQ capped at 16. + tileSizeQ = params.mNumHeadsQPerKv <= 8 ? 8 : 16; + } else { + kernelType = FmhaKernelType::KeepsMmaAbForGeneration; + tileSizeQ = 64; + // Always use the separate reduction kernel. 
+ if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) { + selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel; + } + // For non-sparse MLA, always use 2CTA for numHeadsQPerKv=128 (legacy behavior). + if (params.mNumHeadsQPerKv == 128) { + selectKernelParams.mUses2CtaMma = true; + selectKernelParams.mHeadDimPerCtaV = 256; + } } } } diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index 6d3be7c6d4..86ce33f737 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -203,6 +203,8 @@ struct KernelParams { int32_t mSparseMlaTopK; // The flag to use block sparse attention. bool mUseBlockSparseAttention; + // Whether the indices for K & V pages are shared as unified index (vLLM/FlashInfer). + bool mUsesSharedPagedKvIdx; // Create the TMA shape/stride for Q. template @@ -828,6 +830,8 @@ struct KernelParams { params.mSparseMlaTopK = options.mSparseMlaTopK; // TODO: Integrate trtllm block-sparse attention kernels when needed. params.mUseBlockSparseAttention = false; + // Whether the indices for K & V pages are shared as unified index (vLLM/FlashInfer). + params.mUsesSharedPagedKvIdx = true; return params; } }; From 7faf6cd58f381882be2df94352494c99ca317ce8 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Thu, 19 Mar 2026 06:04:10 +0000 Subject: [PATCH 2/4] refactor: Clean up trtllm-gen kernel selection loop in fmhaKernels Replace unbounded while loop in run() with a bounded for loop (kMaxKernelSelectionPasses=4) to make termination explicit and verifiable. Each re-selection trigger fires at most once, so the sequence always converges within 3 passes in practice. 
Also refactor selectSparseMlaGenerationKernel to remove the !mSelectNewKernel guard pattern on mMultiCtasKvMode assignments, replacing it with an isGmemReduction() check that is semantically equivalent but clearer: it preserves any Disabled mode set by computeCtaAndClusterConfig on re-entry instead of unconditionally overwriting it. Extract buildLaunchConfig() and setNonPortableClusterIfNeeded() as private helpers to eliminate repeated inline setup in run(). Verified no regression across 4512 MLA + 20832 GQA/context tests, and benchmarked GQA (headsQPerKv=4/8/16) and dense MLA decode across batch=[1,4,16,64,256,512], seq=[1024,4096,16384], dtype=[bf16,fp8]: max absolute delta <10us at all reliably-measurable latencies. Add benchmarks/bench_decode_regression.py covering GQA and MLA decode with the above sweep grid for future regression comparisons. --- benchmarks/bench_sparse_mla.py | 50 ++-- .../flashinfer/trtllm/fmha/fmhaKernels.cuh | 242 +++++++++--------- 2 files changed, 156 insertions(+), 136 deletions(-) diff --git a/benchmarks/bench_sparse_mla.py b/benchmarks/bench_sparse_mla.py index c3d44785d9..f6f20b8514 100644 --- a/benchmarks/bench_sparse_mla.py +++ b/benchmarks/bench_sparse_mla.py @@ -13,7 +13,6 @@ import csv import math -import os import random import sys from datetime import datetime @@ -41,7 +40,7 @@ SEQLEN_KVS = [1024, 2048, 4096, 8192, 32768] NUM_HEADS_Q_LIST = [16, 32, 64, 128] DTYPES = [ - ("bf16", torch.bfloat16, torch.bfloat16), # (tag, q_dtype, kv_dtype) + ("bf16", torch.bfloat16, torch.bfloat16), # (tag, q_dtype, kv_dtype) ("e4m3", torch.float8_e4m3fn, torch.float8_e4m3fn), ] @@ -52,23 +51,29 @@ # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def generate_sparse_indices(batch_size, q_len, seq_lens, topk, page_size, block_tables, device): +def generate_sparse_indices( + batch_size, q_len, seq_lens, topk, page_size, 
block_tables, device +): """Returns indices_in_kvcache: [batch_size, q_len, topk] pointing into the flat KV pool.""" block_tables_cpu = block_tables.cpu() seq_lens_cpu = seq_lens.cpu() - indices_in_kvcache = torch.empty(batch_size, q_len, topk, dtype=torch.int32, device="cpu") + indices_in_kvcache = torch.empty( + batch_size, q_len, topk, dtype=torch.int32, device="cpu" + ) for i in range(batch_size): cur_seq_len = int(seq_lens_cpu[i].item()) actual_topk = min(topk, cur_seq_len) for j in range(q_len): cur_abs = torch.arange(0, actual_topk, device="cpu") - cur_blocked = ( - block_tables_cpu[i, cur_abs // page_size] * page_size + (cur_abs % page_size) + cur_blocked = block_tables_cpu[i, cur_abs // page_size] * page_size + ( + cur_abs % page_size ) if actual_topk < topk: - pad = torch.full((topk - actual_topk,), -1, dtype=torch.int32, device="cpu") + pad = torch.full( + (topk - actual_topk,), -1, dtype=torch.int32, device="cpu" + ) cur_blocked = torch.cat([cur_blocked, pad]) indices_in_kvcache[i, j, :] = cur_blocked @@ -116,7 +121,7 @@ def setup_inputs(batch_size, seqlen_kv, num_heads_q, q_dtype, kv_dtype, device): return dict( query=query, - kv_cache=kv_cache.unsqueeze(1), # [blocks, 1, page_size, head_dim] + kv_cache=kv_cache.unsqueeze(1), # [blocks, 1, page_size, head_dim] workspace_buffer=workspace, qk_nope_head_dim=QK_NOPE_HEAD_DIM, kv_lora_rank=KV_LORA_RANK, @@ -178,15 +183,23 @@ def main(): device = torch.device("cuda:0") cc = get_compute_capability(device) if cc[0] != 10: - print(f"ERROR: trtllm-gen sparse MLA requires SM100/SM103, got SM{cc[0]}{cc[1]}") + print( + f"ERROR: trtllm-gen sparse MLA requires SM100/SM103, got SM{cc[0]}{cc[1]}" + ) sys.exit(1) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") csv_path = f"bench_sparse_mla_{timestamp}.csv" fieldnames = [ - "batch_size", "seqlen_kv", "num_heads_q", "dtype", - "sparse_top_k", "median_ms", "std_ms", "bw_tbs", + "batch_size", + "seqlen_kv", + "num_heads_q", + "dtype", + "sparse_top_k", + "median_ms", 
+ "std_ms", + "bw_tbs", ] results = [] @@ -194,8 +207,10 @@ def main(): done = 0 print(f"Running {total} configurations. Results -> {csv_path}\n") - print(f"{'bs':>5} {'seqkv':>7} {'H':>4} {'dtype':>5} {'topk':>6} " - f"{'median_ms':>10} {'std_ms':>8} {'bw_TB/s':>9}") + print( + f"{'bs':>5} {'seqkv':>7} {'H':>4} {'dtype':>5} {'topk':>6} " + f"{'median_ms':>10} {'std_ms':>8} {'bw_TB/s':>9}" + ) print("-" * 65) with open(csv_path, "w", newline="") as f: @@ -208,8 +223,13 @@ def main(): for batch_size in BATCH_SIZES: try: row = run_one( - batch_size, seqlen_kv, num_heads_q, - dtype_tag, q_dtype, kv_dtype, device, + batch_size, + seqlen_kv, + num_heads_q, + dtype_tag, + q_dtype, + kv_dtype, + device, ) results.append(row) writer.writerow(row) diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index b2bca5f4d1..1e64fb329f 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -207,119 +207,130 @@ class TllmGenFmhaKernel { // start here void run(RunnerParams const& params) const { - // The selectKernelParams that might be updated. SelectKernelParams selectKernelParams{params}; - // The parameters for launching the kernel. CtaLaunchParams ctaLaunchParams; - // The iteration index (used to detect a deadlock of selecting new kernels). - int selectKernelIter = 0; - // While loop. - while (true) { - // Any value >= 2 should work here, but we set it larger in case that we - // might have more complicated heuristic in the future. - FLASHINFER_CHECK(selectKernelIter < 8, - "A deadlock is detected when selecting trtllm-gen kernels."); - - // Select the kernel. - selectKernel(params, selectKernelParams); - // Load the kernel. - auto [func, kernelMeta] = loadKernel(params, selectKernelParams); - // Compute the number of CTAs in X, Y and Z dimension and the cluster size in the X dimension. + // Kernel selection loop (bounded). 
Each pass may update selectKernelParams (e.g. switch + // MultiCtasKvMode to Disabled, upgrade to CgaSmemReduction, or reduce headDimPerCtaV) and + // request a re-select via mSelectNewKernel. Each trigger fires at most once, so the sequence + // converges in at most kMaxKernelSelectionPasses passes. + static constexpr int kMaxKernelSelectionPasses = 4; + CUfunction func{}; + KernelMeta kernelMeta{}; + for (int pass = 0; pass < kMaxKernelSelectionPasses; ++pass) { + selectKernel(params, selectKernelParams); + std::tie(func, kernelMeta) = loadKernel(params, selectKernelParams); computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParams); - - // Need to select a new kernel if mSelectNewKernel is true. - if (selectKernelParams.mSelectNewKernel) { - selectKernelIter++; - continue; + if (!selectKernelParams.mSelectNewKernel) { + break; } + FLASHINFER_CHECK(pass + 1 < kMaxKernelSelectionPasses, + "trtllm-gen kernel selection did not converge in %d passes.", + kMaxKernelSelectionPasses); + } - // Prepare the kernel parameters. - auto kernelParams = KernelParams::setKernelParams( - params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv); - - // Prepare kernel parameters list for cuLaunchKernelEx. - void* kernelParamsList[] = {&kernelParams}; - CUlaunchConfig launch_config; - launch_config.blockDimX = kernelMeta.mThreadsPerCTA; - launch_config.blockDimY = 1; - launch_config.blockDimZ = 1; - launch_config.gridDimX = ctaLaunchParams.mNumCtasX; - launch_config.gridDimY = ctaLaunchParams.mNumCtasY; - launch_config.gridDimZ = ctaLaunchParams.mNumCtasZ; - launch_config.hStream = params.stream; - launch_config.sharedMemBytes = kernelMeta.mSharedMemBytes; - - // Debug info. 
- IKL_LOG_DEBUG("TRTLLM-Gen launch info (in TllmGenFmhaKernel %s, %s, %s, %d): kernelName = %s", - toStr(mDtypeQ), toStr(mDtypeKv), toStr(mDtypeOut), mSM, kernelMeta.mFuncName); - IKL_LOG_DEBUG( - "TRTLLM-Gen launch info: maxSeqLenQ = %d, " - "maxSeqLenKv = %d, " - "numHeadsQ = %d, " - "numHeadsKv = %d, batchSize = %d, kernelType = %d", - params.mMaxSeqLenQ, params.mMaxSeqLenKv, params.mNumHeadsQ, params.mNumHeadsKv, - params.mBatchSize, static_cast(params.mKernelType)); - IKL_LOG_DEBUG( - "TRTLLM-Gen launch info: numCtasX = %d, numCtasY = %d, numCtasZ = %d, clusterDimX = %d", - ctaLaunchParams.mNumCtasX, ctaLaunchParams.mNumCtasY, ctaLaunchParams.mNumCtasZ, - ctaLaunchParams.mClusterDimX); - - CUlaunchAttribute launch_attribute[3]; - launch_attribute[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launch_attribute[0].value.clusterDim.x = ctaLaunchParams.mClusterDimX; - launch_attribute[0].value.clusterDim.y = 1; - launch_attribute[0].value.clusterDim.z = 1; - launch_attribute[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; - launch_attribute[1].value.clusterSchedulingPolicyPreference = - ctaLaunchParams.mClusterDimX > 1 ? CU_CLUSTER_SCHEDULING_POLICY_SPREAD - : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT; - launch_attribute[2].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; - launch_attribute[2].value.programmaticStreamSerializationAllowed = params.enable_pdl; - - launch_config.attrs = launch_attribute; - launch_config.numAttrs = 3; - // Add setting for non-portable cluster size. - if (ctaLaunchParams.mClusterDimX > 8) { - cuErrCheck(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, - 1 // Enable non-portable cluster sizes - )); - } + // Prepare the kernel parameters. 
+ auto kernelParams = KernelParams::setKernelParams( + params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv); - // Force using GmemReduction for the multiCtasKvMode if the CgaSmemReduction needs more than - // one wave (due to the cluster occupancy limit). - // TODO: find a better heuristic of using CgaSmemReduction. - if (isCgaSmemReduction(selectKernelParams.mMultiCtasKvMode)) { - // The maximum number of active clusters that could co-exist. - int maxActiveClusters = 1; - cuErrCheck(cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &launch_config)); - // Use the GmemReduction instead if it needs more than one wave. - if (maxActiveClusters * ctaLaunchParams.mClusterDimX < - (ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ)) { - selectKernelParams.mForceGmemReduction = true; - selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReduction; - // continue to select a new kernel. - continue; - } - } + void* kernelParamsList[] = {&kernelParams}; + CUlaunchAttribute launch_attribute[3]; + CUlaunchConfig launch_config; + buildLaunchConfig(launch_config, launch_attribute, kernelMeta, ctaLaunchParams, params); - cuErrCheck(cuLaunchKernelEx(&launch_config, func, kernelParamsList, nullptr)); + // Debug info. 
+ IKL_LOG_DEBUG("TRTLLM-Gen launch info (in TllmGenFmhaKernel %s, %s, %s, %d): kernelName = %s", + toStr(mDtypeQ), toStr(mDtypeKv), toStr(mDtypeOut), mSM, kernelMeta.mFuncName); + IKL_LOG_DEBUG( + "TRTLLM-Gen launch info: maxSeqLenQ = %d, " + "maxSeqLenKv = %d, " + "numHeadsQ = %d, " + "numHeadsKv = %d, batchSize = %d, kernelType = %d", + params.mMaxSeqLenQ, params.mMaxSeqLenKv, params.mNumHeadsQ, params.mNumHeadsKv, + params.mBatchSize, static_cast(params.mKernelType)); + IKL_LOG_DEBUG( + "TRTLLM-Gen launch info: numCtasX = %d, numCtasY = %d, numCtasZ = %d, clusterDimX = %d", + ctaLaunchParams.mNumCtasX, ctaLaunchParams.mNumCtasY, ctaLaunchParams.mNumCtasZ, + ctaLaunchParams.mClusterDimX); - // Run the separate reduction kernel if needed. - tensorrt_llm::kernels::runFmhaReduction(kernelMeta, kernelParams, params.mMultiProcessorCount, - params.enable_pdl, params.stream); + setNonPortableClusterIfNeeded(func, ctaLaunchParams); - if (params.lsePtr != nullptr) { - flashinfer::ComputeLSEFromMD(params.softmaxStatsPtr, params.lsePtr, - params.mSumOfSeqLensQ * params.mNumHeadsQ, params.enable_pdl, - params.stream); + // Force GmemReduction if CgaSmemReduction would need more than one wave (cluster occupancy + // limit). TODO: find a better heuristic of using CgaSmemReduction. 
+ if (isCgaSmemReduction(selectKernelParams.mMultiCtasKvMode)) { + int maxActiveClusters = 1; + cuErrCheck(cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &launch_config)); + if (maxActiveClusters * ctaLaunchParams.mClusterDimX < + ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ) { + selectKernelParams.mForceGmemReduction = true; + selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReduction; + selectKernel(params, selectKernelParams); + std::tie(func, kernelMeta) = loadKernel(params, selectKernelParams); + computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParams); + FLASHINFER_CHECK(!selectKernelParams.mSelectNewKernel, + "trtllm-gen kernel selection did not converge after CgaSmemReduction " + "fallback to GmemReduction."); + // Rebuild kernelParams: setKernelParams uses kernelMeta (TMA descriptors, tile shapes) + // which changed when switching from CgaSmemReduction to GmemReduction kernel. + kernelParams = KernelParams::setKernelParams( + params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv); + buildLaunchConfig(launch_config, launch_attribute, kernelMeta, ctaLaunchParams, params); + setNonPortableClusterIfNeeded(func, ctaLaunchParams); } - // Break the while op. - break; + } + + cuErrCheck(cuLaunchKernelEx(&launch_config, func, kernelParamsList, nullptr)); + + // Run the separate reduction kernel if needed. + tensorrt_llm::kernels::runFmhaReduction(kernelMeta, kernelParams, params.mMultiProcessorCount, + params.enable_pdl, params.stream); + + if (params.lsePtr != nullptr) { + flashinfer::ComputeLSEFromMD(params.softmaxStatsPtr, params.lsePtr, + params.mSumOfSeqLensQ * params.mNumHeadsQ, params.enable_pdl, + params.stream); } } private: + // Fill a CUlaunchConfig and its associated attribute array from the current kernel and CTA + // params. 
The caller owns the storage for launch_attribute (must be an array of at least 3 + // elements) and is responsible for ensuring it outlives launch_config. + void buildLaunchConfig(CUlaunchConfig& launch_config, CUlaunchAttribute* launch_attribute, + KernelMeta const& kernelMeta, CtaLaunchParams const& ctaLaunchParams, + RunnerParams const& params) const { + launch_config.blockDimX = kernelMeta.mThreadsPerCTA; + launch_config.blockDimY = 1; + launch_config.blockDimZ = 1; + launch_config.gridDimX = ctaLaunchParams.mNumCtasX; + launch_config.gridDimY = ctaLaunchParams.mNumCtasY; + launch_config.gridDimZ = ctaLaunchParams.mNumCtasZ; + launch_config.hStream = params.stream; + launch_config.sharedMemBytes = kernelMeta.mSharedMemBytes; + launch_attribute[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launch_attribute[0].value.clusterDim.x = ctaLaunchParams.mClusterDimX; + launch_attribute[0].value.clusterDim.y = 1; + launch_attribute[0].value.clusterDim.z = 1; + launch_attribute[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launch_attribute[1].value.clusterSchedulingPolicyPreference = + ctaLaunchParams.mClusterDimX > 1 ? CU_CLUSTER_SCHEDULING_POLICY_SPREAD + : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT; + launch_attribute[2].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; + launch_attribute[2].value.programmaticStreamSerializationAllowed = params.enable_pdl; + launch_config.attrs = launch_attribute; + launch_config.numAttrs = 3; + } + + // Enable non-portable cluster sizes when clusterDimX exceeds the portable limit of 8. + void setNonPortableClusterIfNeeded(CUfunction func, + CtaLaunchParams const& ctaLaunchParams) const { + if (ctaLaunchParams.mClusterDimX > 8) { + cuErrCheck(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, + 1 // Enable non-portable cluster sizes + )); + } + } + // Is it MLA generation kernel ? 
inline bool isMlaGenKernel(RunnerParams const& params) const { return params.mHeadDimQk == 576 && params.mHeadDimV == 512; @@ -540,15 +551,11 @@ class TllmGenFmhaKernel { if (params.mNumHeadsQPerKv <= 32) { kernelType = FmhaKernelType::SwapsMmaAbForGeneration; selectKernelParams.mTileSizeKv = 128; - // Only set GmemReduction on the first selection pass. - // computeCtaAndClusterConfig may upgrade it to CgaSmemReduction and set mSelectNewKernel=true; - // preserving the updated mode on re-selection avoids an infinite loop. - if (!selectKernelParams.mSelectNewKernel) { - selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReduction; - } - // The base tileSizeQ is numHeadsQPerKv (one CTA covers all Q heads per token). At batch=1 - // the GPU is under-utilized, so we halve tileSizeQ to create 2x more head-splitting CTAs. - // Threshold: batchSize * maxNumCtasPerSeqKv <= MP/8. + // mMultiCtasKvMode defaults to GmemReduction from the constructor. computeCtaAndClusterConfig + // may upgrade it to CgaSmemReduction; that update is preserved naturally across + // re-selections. The base tileSizeQ is numHeadsQPerKv (one CTA covers all Q heads per token). + // At batch=1 the GPU is under-utilized, so we halve tileSizeQ to create 2x more + // head-splitting CTAs. Threshold: batchSize * maxNumCtasPerSeqKv <= MP/8. // effectiveSeqLenKv = min(seqLen, topK) = 2048 -> maxNumCtasPerSeqKv = 16. // Condition: batchSize * 16 <= MP/8 -> batchSize <= 1 (crossover at batch=1->2). // Only halve when half tileSizeQ >= 8 (no valid SwapsMmaAb kernel below tileSizeQ=8). 
@@ -557,19 +564,18 @@ class TllmGenFmhaKernel { int const effectiveSeqLenKv = std::min(params.mMaxSeqLenKv, params.mSparseMlaTopK); int const maxNumCtasPerSeqKv = flashinfer::ceil_div(effectiveSeqLenKv, selectKernelParams.mTileSizeKv); - bool const useHalfTileSizeQ = - halfTileSizeQ >= 8 && - params.mBatchSize * maxNumCtasPerSeqKv <= params.mMultiProcessorCount / 8; + bool const useHalfTileSizeQ = halfTileSizeQ >= 8 && params.mBatchSize * maxNumCtasPerSeqKv <= + params.mMultiProcessorCount / 8; tileSizeQ = useHalfTileSizeQ ? halfTileSizeQ : fullTileSizeQ; } else { // numHeadsQ >= 64: use KeepsMmaAbForGeneration. kernelType = FmhaKernelType::KeepsMmaAbForGeneration; tileSizeQ = 64; selectKernelParams.mTileSizeKv = 128; - // Only set GmemReductionWithSeparateKernel on the first selection pass. - // computeCtaAndClusterConfig may disable it (numCtasPerSeqKv==1) and set mSelectNewKernel=true; - // preserving the updated Disabled mode on re-selection avoids an infinite loop. - if (!selectKernelParams.mSelectNewKernel) { + // Upgrade GmemReduction (constructor default) to GmemReductionWithSeparateKernel. + // If computeCtaAndClusterConfig already set it to Disabled (numCtasPerSeqKv==1), the + // isGmemReduction() guard is false and the Disabled state is preserved on re-selection. + if (isGmemReduction(selectKernelParams.mMultiCtasKvMode)) { selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel; } // For numHeadsQ=128, use 2CTA when there are enough CTAs to amortize 2CTA overhead. @@ -584,11 +590,6 @@ class TllmGenFmhaKernel { if (use2Cta) { selectKernelParams.mUses2CtaMma = true; selectKernelParams.mHeadDimPerCtaV = 256; - } else if (!selectKernelParams.mSelectNewKernel) { - // Only set headDimPerCtaV on the first selection pass. - // computeCtaAndClusterConfig may reduce it and set mSelectNewKernel=true; - // preserving the updated value on re-selection avoids an infinite loop. 
- selectKernelParams.mHeadDimPerCtaV = 512; } } } @@ -605,8 +606,7 @@ class TllmGenFmhaKernel { selectSparseMlaGenerationKernel(params, selectKernelParams); } else { // Non-sparse MLA: use SwapsMmaAb when numHeadsQPerKv <= 32 or seqLenPerCtaKv is small. - bool const useSwapsMmaAb = - params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params); + bool const useSwapsMmaAb = params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params); if (useSwapsMmaAb) { kernelType = FmhaKernelType::SwapsMmaAbForGeneration; From a7428facc537b771bc54fab5783dc98eacaa0af1 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Fri, 20 Mar 2026 05:27:43 +0000 Subject: [PATCH 3/4] chore: Update trtllm-gen FMHA artifact checksum --- flashinfer/artifacts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 0b517ab023..d0017c6630 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -155,7 +155,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "8e99084003b6bbc07a9ea61822c32de649254594065cbc52ebb020e2b4ef1752" + "5bd87798e560a63e883902fc5468146ffff0d3551bf337d2f81bd02893e9dc39" ) TRTLLM_GEN_BMM: str = ( "0af823880730c4f0b3832d2208fab035946694b83444410b9309db5613d60195" From e2b581945b3f32f0d6f16d80f2c31c04802566d3 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Fri, 20 Mar 2026 05:36:18 +0000 Subject: [PATCH 4/4] chore: Remove bench_sparse_mla.py benchmark script --- benchmarks/bench_sparse_mla.py | 250 --------------------------------- 1 file changed, 250 deletions(-) delete mode 100644 benchmarks/bench_sparse_mla.py diff --git a/benchmarks/bench_sparse_mla.py b/benchmarks/bench_sparse_mla.py deleted file mode 100644 index f6f20b8514..0000000000 --- a/benchmarks/bench_sparse_mla.py +++ /dev/null @@ -1,250 +0,0 @@ -""" -Benchmark for sparse MLA (trtllm-gen backend) across a grid of: - batch_size : 1, 
32, 128, 512 - seqlen_kv : 1024, 2048, 4096, 8192, 32768 - num_heads_q : 16, 32, 64, 128 - dtype : bf16 (query+kv+out), e4m3 (query+kv, bf16 out) - -DeepSeek-V3 sparse MLA config: - kv_lora_rank = 512, qk_rope_head_dim = 64, qk_nope_head_dim = 512 - sparse_mla_top_k = min(2048, seqlen_kv) - page_size = 32 -""" - -import csv -import math -import random -import sys -from datetime import datetime - -import torch - -import flashinfer -from flashinfer.testing.utils import bench_gpu_time -from flashinfer.utils import get_compute_capability - -# --------------------------------------------------------------------------- -# DeepSeek-V3 MLA dims -# --------------------------------------------------------------------------- -KV_LORA_RANK = 512 -QK_ROPE_HEAD_DIM = 64 -QK_NOPE_HEAD_DIM = KV_LORA_RANK # = 512 -QK_HEAD_DIM = KV_LORA_RANK + QK_ROPE_HEAD_DIM # = 576 -PAGE_SIZE = 32 -SPARSE_TOP_K_MAX = 2048 - -# --------------------------------------------------------------------------- -# Sweep parameters -# --------------------------------------------------------------------------- -BATCH_SIZES = [1, 32, 128, 512] -SEQLEN_KVS = [1024, 2048, 4096, 8192, 32768] -NUM_HEADS_Q_LIST = [16, 32, 64, 128] -DTYPES = [ - ("bf16", torch.bfloat16, torch.bfloat16), # (tag, q_dtype, kv_dtype) - ("e4m3", torch.float8_e4m3fn, torch.float8_e4m3fn), -] - -NUM_ITERS = 30 -DRY_RUN_ITERS = 5 - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- -def generate_sparse_indices( - batch_size, q_len, seq_lens, topk, page_size, block_tables, device -): - """Returns indices_in_kvcache: [batch_size, q_len, topk] pointing into the flat KV pool.""" - block_tables_cpu = block_tables.cpu() - seq_lens_cpu = seq_lens.cpu() - - indices_in_kvcache = torch.empty( - batch_size, q_len, topk, dtype=torch.int32, device="cpu" - ) - - for i in range(batch_size): - cur_seq_len = int(seq_lens_cpu[i].item()) - 
actual_topk = min(topk, cur_seq_len) - for j in range(q_len): - cur_abs = torch.arange(0, actual_topk, device="cpu") - cur_blocked = block_tables_cpu[i, cur_abs // page_size] * page_size + ( - cur_abs % page_size - ) - if actual_topk < topk: - pad = torch.full( - (topk - actual_topk,), -1, dtype=torch.int32, device="cpu" - ) - cur_blocked = torch.cat([cur_blocked, pad]) - indices_in_kvcache[i, j, :] = cur_blocked - - return indices_in_kvcache.to(device) - - -def setup_inputs(batch_size, seqlen_kv, num_heads_q, q_dtype, kv_dtype, device): - """Create all tensors needed for a sparse MLA decode call.""" - topk = min(SPARSE_TOP_K_MAX, seqlen_kv) - q_len = 1 # decode phase - - # Query: [B, q_len, H, QK_HEAD_DIM] - query = torch.randn(batch_size, q_len, num_heads_q, QK_HEAD_DIM, device=device) - query.clamp_(-1.0, 1.0) - query = query.to(q_dtype) - - # KV cache pool - seq_lens = torch.full((batch_size,), seqlen_kv, dtype=torch.int32, device=device) - blocks_per_seq = (seq_lens + PAGE_SIZE - 1) // PAGE_SIZE - total_blocks = int(blocks_per_seq.sum().item()) - - all_block_ids = torch.randperm(total_blocks, device=device) - max_blocks = int(blocks_per_seq.max().item()) - block_tables = torch.zeros(batch_size, max_blocks, dtype=torch.int32, device=device) - bid = 0 - for i in range(batch_size): - nb = int(blocks_per_seq[i].item()) - block_tables[i, :nb] = all_block_ids[bid : bid + nb] - bid += nb - - kv_cache = torch.randn(total_blocks, PAGE_SIZE, QK_HEAD_DIM, device=device) - kv_cache.clamp_(-1.0, 1.0) - kv_cache = kv_cache.to(kv_dtype) - - # Sparse indices: [B, q_len, topk] - indices_in_kvcache = generate_sparse_indices( - batch_size, q_len, seq_lens, topk, PAGE_SIZE, block_tables, device - ) - - # Workspace (zero-initialised, as required) - workspace = torch.zeros(256 * 1024 * 1024, dtype=torch.int8, device=device) - - bmm1_scale = 1.0 / math.sqrt(QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM) - bmm2_scale = 1.0 - - return dict( - query=query, - kv_cache=kv_cache.unsqueeze(1), # 
[blocks, 1, page_size, head_dim] - workspace_buffer=workspace, - qk_nope_head_dim=QK_NOPE_HEAD_DIM, - kv_lora_rank=KV_LORA_RANK, - qk_rope_head_dim=QK_ROPE_HEAD_DIM, - block_tables=indices_in_kvcache, - seq_lens=seq_lens, - max_seq_len=seqlen_kv, - sparse_mla_top_k=topk, - bmm1_scale=bmm1_scale, - bmm2_scale=bmm2_scale, - backend="trtllm-gen", - ) - - -def run_one(batch_size, seqlen_kv, num_heads_q, dtype_tag, q_dtype, kv_dtype, device): - topk = min(SPARSE_TOP_K_MAX, seqlen_kv) - kwargs = setup_inputs(batch_size, seqlen_kv, num_heads_q, q_dtype, kv_dtype, device) - - # Warmup + measure - measurements = bench_gpu_time( - flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla, - dry_run_iters=DRY_RUN_ITERS, - repeat_iters=NUM_ITERS, - enable_cupti=True, - use_cuda_graph=True, - input_kwargs=kwargs, - ) - median_ms = float(torch.tensor(measurements).median().item()) - std_ms = float(torch.tensor(measurements).float().std().item()) - - # Memory-bandwidth estimate: kv bytes accessed - def elem_bytes(dtype): - return torch.empty(1, dtype=dtype).element_size() - - kv_bytes = batch_size * topk * QK_HEAD_DIM * elem_bytes(kv_dtype) - q_bytes = batch_size * num_heads_q * QK_HEAD_DIM * elem_bytes(q_dtype) - o_bytes = batch_size * num_heads_q * KV_LORA_RANK * 2 # bf16 output always - total_bytes = kv_bytes + q_bytes + o_bytes - bw_tbs = total_bytes / median_ms / 1e9 - - print( - f"bs={batch_size:4d} seqkv={seqlen_kv:6d} H={num_heads_q:3d} " - f"dtype={dtype_tag} topk={topk:5d} " - f"median={median_ms:.3f}ms std={std_ms:.3f}ms bw={bw_tbs:.2f}TB/s" - ) - return dict( - batch_size=batch_size, - seqlen_kv=seqlen_kv, - num_heads_q=num_heads_q, - dtype=dtype_tag, - sparse_top_k=topk, - median_ms=median_ms, - std_ms=std_ms, - bw_tbs=bw_tbs, - ) - - -def main(): - device = torch.device("cuda:0") - cc = get_compute_capability(device) - if cc[0] != 10: - print( - f"ERROR: trtllm-gen sparse MLA requires SM100/SM103, got SM{cc[0]}{cc[1]}" - ) - sys.exit(1) - - timestamp = 
datetime.now().strftime("%Y%m%d_%H%M%S") - csv_path = f"bench_sparse_mla_{timestamp}.csv" - - fieldnames = [ - "batch_size", - "seqlen_kv", - "num_heads_q", - "dtype", - "sparse_top_k", - "median_ms", - "std_ms", - "bw_tbs", - ] - - results = [] - total = len(BATCH_SIZES) * len(SEQLEN_KVS) * len(NUM_HEADS_Q_LIST) * len(DTYPES) - done = 0 - - print(f"Running {total} configurations. Results -> {csv_path}\n") - print( - f"{'bs':>5} {'seqkv':>7} {'H':>4} {'dtype':>5} {'topk':>6} " - f"{'median_ms':>10} {'std_ms':>8} {'bw_TB/s':>9}" - ) - print("-" * 65) - - with open(csv_path, "w", newline="") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - - for dtype_tag, q_dtype, kv_dtype in DTYPES: - for num_heads_q in NUM_HEADS_Q_LIST: - for seqlen_kv in SEQLEN_KVS: - for batch_size in BATCH_SIZES: - try: - row = run_one( - batch_size, - seqlen_kv, - num_heads_q, - dtype_tag, - q_dtype, - kv_dtype, - device, - ) - results.append(row) - writer.writerow(row) - f.flush() - except Exception as e: - print( - f" SKIP bs={batch_size} seqkv={seqlen_kv} " - f"H={num_heads_q} dtype={dtype_tag}: {e}" - ) - done += 1 - - print(f"\nDone. {len(results)}/{total} succeeded. CSV saved to {csv_path}") - - -if __name__ == "__main__": - torch.manual_seed(42) - random.seed(42) - main()