import numpy as np
import torch

import flashinfer
from flashinfer.testing.utils import bench_gpu_time_with_cudagraph

num_q_heads = 128
num_kv_heads = 1
qk_nope_head_dim = 128
qk_rope_head_dim = 64
kv_lora_rank = 512
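# These constants match the MLA geometry of DeepSeek-V3-style models: 128 query
# heads over a 512-dim compressed KV latent (kv_lora_rank) plus a 64-dim RoPE
# component. With weight absorption the query is projected into the same latent
# space, hence the per-head query dim of kv_lora_rank + qk_rope_head_dim below.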


def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype):
    torch.manual_seed(42)
    device = "cuda:0"

    # Initialize tensors
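    # (torch.randn cannot sample float8 dtypes directly, so tensors are drawn
    # in the default float dtype and then cast to the target dtype)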
    query = torch.randn(
        batch_size,
        q_len_per_request,
        num_q_heads,
        kv_lora_rank + qk_rope_head_dim,
        device=device,
    ).to(dtype)

    num_tokens = seq_len * batch_size
    num_blocks = (num_tokens + page_size - 1) // page_size

    # Sequence lengths and block tables; the last request is pinned to the
    # full seq_len so that max_seq_len == seq_len
    seq_lens = [torch.randint(1, seq_len, (1,)).item() for _ in range(batch_size)]
    seq_lens[-1] = seq_len
    max_seq_len = max(seq_lens)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)

    blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size
    max_num_blocks_per_seq = blocks_per_seq.max().item()

    # Generate a random permutation so block IDs are unique across sequences
    total_blocks_needed = int(blocks_per_seq.sum().item())
    all_block_ids = torch.randperm(total_blocks_needed, device=device)

    block_tables = torch.zeros(
        (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device
    )

    # Populate block tables: each sequence takes a disjoint slice of the
    # permuted block IDs
    block_id = 0
    for i in range(batch_size):
        num_blocks_needed = int(blocks_per_seq[i].item())
        block_tables[i, :num_blocks_needed] = all_block_ids[
            block_id : block_id + num_blocks_needed
        ]
        block_id += num_blocks_needed
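
    # Optional sanity check (a small sketch added here, not part of the original
    # benchmark): all assigned page IDs should be unique and should index into
    # the allocated kv_cache. This holds because page_size divides the seq_len
    # values swept below, so total_blocks_needed <= num_blocks.
    used_ids = torch.cat(
        [block_tables[i, : int(blocks_per_seq[i].item())] for i in range(batch_size)]
    )
    assert used_ids.unique().numel() == used_ids.numel(), "block IDs must be unique"
    assert used_ids.max().item() < num_blocks, "block ID out of range"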

    # Create interleaved KV cache. Allocate num_blocks pages even though
    # block_id pages suffice, to mimic real-world over-allocation.
    kv_cache = torch.randn(
        size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), device=device
    ).to(dtype)
    # The kernel expects shape (num_blocks, 1, page_size, kv_lora_rank +
    # qk_rope_head_dim), hence the unsqueeze(1) at the call sites below.

    # Allocate workspace buffer
    # todo(Yingyi): calculate the actual size of workspace buffer
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)

    # Run decode-MLA
    # warmup
    flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
        query=query,
        kv_cache=kv_cache.unsqueeze(1),
        workspace_buffer=workspace_buffer,
        qk_nope_head_dim=qk_nope_head_dim,
        kv_lora_rank=kv_lora_rank,
        qk_rope_head_dim=qk_rope_head_dim,
        block_tables=block_tables,
        seq_lens=seq_lens_tensor,
        max_seq_len=max_seq_len,
        # softmax scale over the full QK head dim (nope + rope = 192)
        bmm1_scale=1.0 / ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5),
        bmm2_scale=1.0,
    )
    # benchmark
    measurements = bench_gpu_time_with_cudagraph(
        lambda: flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
            query=query,
            kv_cache=kv_cache.unsqueeze(1),
            workspace_buffer=workspace_buffer,
            qk_nope_head_dim=qk_nope_head_dim,
            kv_lora_rank=kv_lora_rank,
            qk_rope_head_dim=qk_rope_head_dim,
            block_tables=block_tables,
            seq_lens=seq_lens_tensor,
            max_seq_len=max_seq_len,
            bmm1_scale=1.0 / ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5),
            bmm2_scale=1.0,
        ),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )
    # Approximate bytes moved: the full query plus the entire KV cache. This is
    # an upper bound, since only seq_lens tokens per request are actually read.
    io = (
        query.numel() * query.element_size()
        + kv_cache.numel() * kv_cache.element_size()
    )
    # median of the per-iteration latencies (in ms)
    ms = np.median(measurements)
    # FLOPs per (query token, KV token) pair: 2 * (kv_lora_rank + qk_rope_head_dim)
    # for Q @ K^T plus 2 * kv_lora_rank for P @ V,
    # i.e. 2 * (2 * kv_lora_rank + qk_rope_head_dim)
    flops = (
        2
        * batch_size
        * num_q_heads
        * (2 * kv_lora_rank + qk_rope_head_dim)
        * seq_len
        * q_len_per_request
    )
    print(
        f"batch_size={batch_size}, q_len_per_request={q_len_per_request}, "
        f"seq_len={seq_len}, num_q_heads={num_q_heads}, num_kv_heads={num_kv_heads}, "
        f"qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, "
        f"kv_lora_rank={kv_lora_rank}, page_size={page_size}"
    )
    print(f"execution time: {ms:.3f} ms")
    # io is in bytes and ms in milliseconds, so bytes/ms * 1e-6 == GB/s
    print(f"memory bandwidth: {io / ms * 1e-6:.2f} GB/s")
    print(f"FLOPs: {flops * 1e-9 / ms:.2f} TFLOPs/s")
if __name__ == "__main__":
    for dtype in [torch.bfloat16, torch.float8_e4m3fn]:
        for page_size in [32, 64]:
            for batch_size in [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024]:
                for seq_len in [1024, 4096, 8192]:
                    for q_len_per_request in [1, 2, 4, 8, 16]:
                        bench_trtllm_mla(
                            batch_size, q_len_per_request, seq_len, page_size, dtype
                        )