-
Notifications
You must be signed in to change notification settings - Fork 813
Open
Labels
Description
Not an urgent issue; I just came across this while running a quick benchmark.
import numpy as np
import torch
from triton.testing import do_bench
import flashinfer
# Benchmark configuration (module-level constants used by the __main__ sweep).
# page_block_size = 16
page_block_size = 1  # one token per KV-cache page; block and element indptr coincide
num_kv_heads = 4  # GQA: 32 query heads share 4 KV heads
num_qo_heads = 32
head_dim = 128  # per-head feature dimension
def bench_batch_decode(
    batch_size,
    seq_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_block_size,
    q_dtype,
    kv_dtype,
):
    """Time paged batch decode against ragged prefill for one configuration.

    Builds a paged KV cache of `batch_size` requests, each `seq_len` tokens
    long, runs the same workload through both flashinfer wrappers, and prints
    the elapsed time plus an effective memory-bandwidth figure derived from
    the bytes of q and the KV cache.

    NOTE(review): the ragged plan reuses the block-level indptr as an
    element-level indptr — these only coincide when page_block_size == 1.
    """
    np.random.seed(42)

    # Every request has the same length; derive how many pages each one needs
    # and the standard CSR-style indptr over pages.
    lengths = torch.full((batch_size,), seq_len)
    pages_per_req = torch.ceil(lengths / page_block_size).int()
    indptr = torch.cat(
        [torch.tensor([0]), torch.cumsum(pages_per_req, 0)], dim=0
    ).int()
    # Number of valid tokens in each request's final (possibly partial) page.
    tail_len = (lengths - (pages_per_req - 1) * page_block_size).int()
    total_pages = indptr[-1].item()

    q = torch.rand(batch_size, num_qo_heads, head_dim, dtype=q_dtype, device="cuda:0")
    # Combined K/V storage: [pages, 2 (k|v), page_size, kv_heads, head_dim].
    kv_data = torch.randn(
        total_pages, 2, page_block_size, num_kv_heads, head_dim, device="cuda:0"
    ).to(kv_dtype)

    scratch = torch.empty(
        128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"
    )
    decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        scratch, kv_layout="NHD", use_tensor_cores=True
    )
    decode_wrapper.plan(
        indptr.to(0),
        torch.arange(total_pages).int().to(0),  # identity page table
        tail_len.to(0),
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_block_size,
        data_type=kv_dtype,
        q_data_type=q_dtype,
    )
    ms = do_bench(lambda: decode_wrapper.run(q, kv_data))

    prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"),
        kv_layout="NHD",
        backend="fa3",
    )
    # One query row per request token: [0, seq_len, 2*seq_len, ...].
    query_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int()
    prefill_wrapper.plan(
        query_indptr,
        indptr,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        causal=True,
        q_data_type=q_dtype,
        kv_data_type=kv_dtype,
    )
    # Split combined storage into K and V, then drop the size-1 dims.
    # NOTE(review): squeeze() also removes the page dimension when
    # page_block_size == 1; with larger pages the shape would differ — confirm.
    keys, vals = kv_data.chunk(2, dim=1)
    keys = keys.squeeze()
    vals = vals.squeeze()
    ms_ragged = do_bench(lambda: prefill_wrapper.run(q, keys, vals))

    # Bytes moved per run: all of q plus the whole KV cache, read once.
    io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
    print(f"max memory allocated: {torch.cuda.max_memory_allocated() / 1024 / 1024 :.2f} MB")
    print(
        f"batch_size={batch_size}, seq_len={seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_block_size={page_block_size}, q_dtype={q_dtype}, kv_dtype={kv_dtype}"
    )
    print(
        f"Batch Decode execution time: {ms}ms, Ragged prefill execution time: {ms_ragged}ms"
    )
    print(
        f"Batch Decode memory bandwidth: {io / ms / 1024 / 1024 :.2f} GB/s, Ragged prefill memory bandwidth: {io / ms_ragged / 1024 / 1024 :.2f} GB/s"
    )
if __name__ == "__main__":
    # Sweep batch sizes for each q/kv dtype combination at one fixed length.
    for query_dtype in [torch.bfloat16]:
        for cache_dtype in [torch.bfloat16, torch.float8_e4m3fn]:
            for bs in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]:
                for length in [16384]:
                    bench_batch_decode(
                        bs,
                        length,
                        num_qo_heads,
                        num_kv_heads,
                        head_dim,
                        page_block_size,
                        query_dtype,
                        cache_dtype,
                    )
Reactions are currently unavailable