Skip to content

Ragged prefill SM90 fails with illegal memory access for total seqlen >= 16384 #1419

@Edenzzzz

Description

@Edenzzzz

Not an urgent issue, just came across this in a random test

(Screenshot from the original report omitted; the reproduction script follows.)
import numpy as np
import torch
from triton.testing import do_bench

import flashinfer

# Benchmark configuration (shared by bench_batch_decode via the __main__ sweep).
# page_block_size = 16
page_block_size = 1  # tokens per KV-cache page; the issue reproduces with 1
num_kv_heads = 4  # KV heads (fewer than query heads, i.e. shared across them)
num_qo_heads = 32  # query/output heads
head_dim = 128  # per-head feature dimension


def bench_batch_decode(
    batch_size,
    seq_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_block_size,
    q_dtype,
    kv_dtype,
):
    """Benchmark flashinfer batch decode vs. ragged prefill on identical data.

    Builds a paged KV cache holding ``batch_size`` sequences of length
    ``seq_len``, times ``BatchDecodeWithPagedKVCacheWrapper.run`` and
    ``BatchPrefillWithRaggedKVCacheWrapper.run`` with triton's ``do_bench``,
    then prints execution time and effective memory bandwidth for both.

    Requires a CUDA device; all tensors are placed on ``cuda:0``.
    """
    np.random.seed(42)
    # Uniform per-request sequence lengths and the number of pages each
    # sequence occupies (last page may be partial).
    seq_lens = torch.full((batch_size,), seq_len)
    seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()
    # CSR-style indptr over pages: kv_indptr[i] is the first page of request i.
    kv_indptr = torch.cat([torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0)
    kv_indptr = kv_indptr.int()
    # Number of valid tokens in each request's final page.
    last_page_len = seq_lens - (seq_lens_blocks - 1) * page_block_size
    last_page_len = last_page_len.int()
    num_blocks = kv_indptr[-1].item()

    q = torch.rand(batch_size, num_qo_heads, head_dim, dtype=q_dtype, device="cuda:0")
    # Paged cache layout: (num_blocks, 2 [K/V], page_block_size, num_kv_heads, head_dim).
    kv_data = torch.randn(
        num_blocks, 2, page_block_size, num_kv_heads, head_dim, device="cuda:0"
    ).to(kv_dtype)
    workspace_buffer = torch.empty(
        128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"
    )
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD", use_tensor_cores=True
    )
    wrapper.plan(
        kv_indptr.to(0),
        torch.arange(num_blocks).int().to(0),  # identity page table
        last_page_len.to(0),
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_block_size,
        data_type=kv_dtype,
        q_data_type=q_dtype,
    )

    ms = do_bench(lambda: wrapper.run(q, kv_data))

    wrapper_ragged = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"),
        kv_layout="NHD",
        backend="fa3",
    )
    # One query token per request in decode, so qo_indptr steps by seq_len
    # would be wrong for q of shape (batch, heads, dim)?  NOTE(review): this
    # mirrors the original script; q here has batch_size rows while qo_indptr
    # spans batch_size * seq_len tokens — confirm against the wrapper's
    # expectations when investigating the reported crash.
    qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int()
    # The ragged kernel indexes K/V by token, not by page, so its indptr must
    # be token-granularity. With all sequences the same length it equals
    # qo_indptr; the block-granularity kv_indptr above only coincides with it
    # when page_block_size == 1.
    kv_indptr_tokens = qo_indptr
    wrapper_ragged.plan(
        qo_indptr,
        kv_indptr_tokens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        causal=True,
        q_data_type=q_dtype,
        kv_data_type=kv_dtype,
    )
    # Flatten the paged cache into the ragged (total_tokens, num_kv_heads,
    # head_dim) layout. The original `.squeeze()` only worked for
    # page_block_size == 1 and would also collapse any other accidental
    # size-1 dim (single block / single KV head); an explicit reshape is
    # correct for every page size.
    k, v = kv_data.unbind(dim=1)
    k = k.reshape(-1, num_kv_heads, head_dim)
    v = v.reshape(-1, num_kv_heads, head_dim)
    ms_ragged = do_bench(lambda: wrapper_ragged.run(q, k, v))

    # Bytes moved per run: the query plus the full KV cache (read once each).
    io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
    print(f"max memory allocated: {torch.cuda.max_memory_allocated() / 1024 / 1024 :.2f} MB")
    print(
        f"batch_size={batch_size}, seq_len={seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_block_size={page_block_size}, q_dtype={q_dtype}, kv_dtype={kv_dtype}"
    )
    print(
        f"Batch Decode execution time: {ms}ms, Ragged prefill execution time: {ms_ragged}ms"
    )
    print(
        f"Batch Decode memory bandwidth: {io / ms / 1024 / 1024 :.2f} GB/s, Ragged prefill memory bandwidth: {io / ms_ragged / 1024 / 1024 :.2f} GB/s"
    )


if __name__ == "__main__":
    # Enumerate the full sweep up front; iteration order is
    # q_dtype -> kv_dtype -> batch_size -> seq_len, matching the grid below.
    sweep = [
        (q_dt, kv_dt, bs, sl)
        for q_dt in (torch.bfloat16,)
        for kv_dt in (torch.bfloat16, torch.float8_e4m3fn)
        for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256, 512)
        for sl in (16384,)
    ]
    for q_dt, kv_dt, bs, sl in sweep:
        bench_batch_decode(
            bs,
            sl,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_block_size,
            q_dt,
            kv_dt,
        )

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions