
Commit 68d1608

yyihuang and yzh119 authored
minor: add trtllm_gen_mla benchmark (#1316)
## 📌 Description

Add a missing benchmark.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: Zihao Ye <[email protected]>
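For reviewers who want to exercise the new benchmark without the full sweep in `__main__`, a minimal sketch follows. It assumes a CUDA-capable GPU supported by the TRTLLM-gen MLA decode kernel, a FlashInfer build that exposes `flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla`, and that the `benchmarks/` directory is on `sys.path`; the configuration shown is just one point from the sweep.

```python
# Minimal single-configuration run of the new benchmark (sketch, not part of the PR).
# Assumes benchmarks/ is on sys.path and a GPU supported by the TRTLLM-gen MLA kernel.
import torch

from bench_trtllm_gen_mla import bench_trtllm_mla

# One point from the sweep in __main__: bf16, 32-token pages, 16 requests,
# 1 query token per request, 4K-token KV history per request.
bench_trtllm_mla(
    batch_size=16,
    q_len_per_request=1,
    seq_len=4096,
    page_size=32,
    dtype=torch.bfloat16,
)
```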
1 parent 12d48c6 commit 68d1608

File tree

1 file changed: +134 -0 lines changed


benchmarks/bench_trtllm_gen_mla.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
import numpy as np
import torch
import triton

import flashinfer
from flashinfer.testing.utils import bench_gpu_time, bench_gpu_time_with_cudagraph

num_q_heads = 128
num_kv_heads = 1
qk_nope_head_dim = 128
qk_rope_head_dim = 64
kv_lora_rank = 512


def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype):
    torch.manual_seed(42)
    device = "cuda:0"

    # Initialize tensors
    query = torch.randn(
        batch_size,
        q_len_per_request,
        num_q_heads,
        kv_lora_rank + qk_rope_head_dim,
        device=device,
    ).to(dtype)

    num_tokens = seq_len * batch_size
    num_blocks = (num_tokens + page_size - 1) // page_size

    # Sequence lengths and block tables
    seq_lens = [torch.randint(1, seq_len, (1,)).item() for _ in range(batch_size)]
    seq_lens[-1] = seq_len
    max_seq_len = max(seq_lens)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)

    blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size
    max_num_blocks_per_seq = blocks_per_seq.max().item()

    # Generate random but unique block IDs for all sequences
    total_blocks_needed = sum(blocks_per_seq)
    all_block_ids = torch.randperm(
        total_blocks_needed, device=device
    )  # Random permutation

    # Generate unique block IDs for all sequences
    block_id = 0
    block_tables = torch.zeros(
        (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device
    )

    # Populate block tables and track block assignments
    block_id = 0
    for i in range(batch_size):
        num_blocks_needed = blocks_per_seq[i]
        block_tables[i, :num_blocks_needed] = all_block_ids[
            block_id : block_id + num_blocks_needed
        ]
        block_id += num_blocks_needed

    # Create interleaved KV cache
    # Allocate more blocks (num_blocks) than strictly needed (block_id) to mimic
    # real-world over-allocation
    kv_cache = torch.randn(
        size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), device=device
    ).to(dtype)
    # unsqueeze(1) below gives (num_blocks, 1, page_size, kv_lora_rank + qk_rope_head_dim)

    # Allocate workspace buffer
    # todo(Yingyi): calculate the actual size of workspace buffer
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)

    # Run decode-MLA
    # warmup
    flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
        query=query,
        kv_cache=kv_cache.unsqueeze(1),
        workspace_buffer=workspace_buffer,
        qk_nope_head_dim=qk_nope_head_dim,
        kv_lora_rank=kv_lora_rank,
        qk_rope_head_dim=qk_rope_head_dim,
        block_tables=block_tables,
        seq_lens=seq_lens_tensor,
        max_seq_len=max_seq_len,
        bmm1_scale=1.0 / ((128 + 64) ** 0.5),  # 1 / sqrt(qk_nope_head_dim + qk_rope_head_dim)
        bmm2_scale=1.0,
    )
    # benchmark
    measurements = bench_gpu_time_with_cudagraph(
        lambda: flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
            query=query,
            kv_cache=kv_cache.unsqueeze(1),
            workspace_buffer=workspace_buffer,
            qk_nope_head_dim=qk_nope_head_dim,
            kv_lora_rank=kv_lora_rank,
            qk_rope_head_dim=qk_rope_head_dim,
            block_tables=block_tables,
            seq_lens=seq_lens_tensor,
            max_seq_len=max_seq_len,
            bmm1_scale=1.0 / ((128 + 64) ** 0.5),
            bmm2_scale=1.0,
        ),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )
    io = (
        query.numel() * query.element_size()
        + kv_cache.numel() * kv_cache.element_size()
    )
    ms = np.median(measurements)
    flops = (
        2
        * batch_size
        * num_q_heads
        * (2 * kv_lora_rank + qk_rope_head_dim)
        * seq_len
        * q_len_per_request
    )
    print(
        f"batch_size={batch_size}, q_len_per_request={q_len_per_request}, seq_len={seq_len}, num_q_heads={num_q_heads}, num_kv_heads={num_kv_heads}, qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, kv_lora_rank={kv_lora_rank}, page_size={page_size}"
    )
    print(f"execution time: {ms} ms")
    print(f"memory bandwidth: {io / ms / 1024 / 1024 :.2f} GB/s")
    print(f"FLOPs: {flops * 1e-9 / ms:.2f} TFLOPs/s")


if __name__ == "__main__":
    for dtype in [torch.bfloat16, torch.float8_e4m3fn]:
        for page_size in [32, 64]:
            for batch_size in [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024]:
                for seq_len in [1024, 4096, 8192]:
                    for q_len_per_request in [1, 2, 4, 8, 16]:
                        bench_trtllm_mla(
                            batch_size, q_len_per_request, seq_len, page_size, dtype
                        )
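To make the printed metrics concrete, the sketch below re-derives the script's byte-count and FLOP-count formulas for one configuration. The median kernel time is a made-up placeholder rather than a measurement, and `bytes_per_elem = 2` corresponds to the bfloat16 case.

```python
# Re-derivation of bench_trtllm_mla's reported metrics for one configuration.
# The 0.25 ms median time is a hypothetical placeholder, not a measured value.
batch_size, q_len_per_request, seq_len, page_size = 16, 1, 4096, 32
num_q_heads, kv_lora_rank, qk_rope_head_dim = 128, 512, 64
head_dim = kv_lora_rank + qk_rope_head_dim  # 576
bytes_per_elem = 2  # torch.bfloat16

# Bytes touched: the query plus the (intentionally over-allocated) paged KV cache.
num_blocks = (seq_len * batch_size + page_size - 1) // page_size
query_bytes = batch_size * q_len_per_request * num_q_heads * head_dim * bytes_per_elem
kv_bytes = num_blocks * page_size * head_dim * bytes_per_elem
io = query_bytes + kv_bytes

# The script's FLOP count factors into one QK^T matmul over head_dim and one PV
# matmul over kv_lora_rank per query head and KV token:
# 2*head_dim + 2*kv_lora_rank == 2*(2*kv_lora_rank + qk_rope_head_dim).
flops = (
    2 * batch_size * num_q_heads * (2 * kv_lora_rank + qk_rope_head_dim)
    * seq_len * q_len_per_request
)

ms = 0.25  # hypothetical median kernel time in milliseconds
print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s")  # same convention as the script
print(f"FLOPs: {flops * 1e-9 / ms:.2f} TFLOPs/s")
```

Feeding a real median from `bench_gpu_time_with_cudagraph` into `ms` reproduces the script's printed numbers, since the formulas above are copied verbatim from `bench_trtllm_mla`.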

0 commit comments
