diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 6969167e6..54b065667 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1288,6 +1288,13 @@ def run( self._cached_module.paged_run(*run_args) else: + # trtllm-gen does not need plan info + if self._backend == "trtllm-gen" and self._plan_info is None: + plan_info: List[int] = [] + else: + plan_info = self._plan_info + assert plan_info is not None, "plan info is not initialized" + run_args = [ self._float_workspace_buffer, self._int_workspace_buffer, diff --git a/scripts/run_test_blackwell_attention_kernels.sh b/scripts/run_test_blackwell_attention_kernels.sh index dd5e23131..1eddcd5e1 100644 --- a/scripts/run_test_blackwell_attention_kernels.sh +++ b/scripts/run_test_blackwell_attention_kernels.sh @@ -7,8 +7,8 @@ pytest -s tests/test_blackwell_fmha.py pytest -s tests/test_deepseek_mla.py # trtllm-gen -pytest -s tests/test_trtllm_gen_context.py -pytest -s tests/test_trtllm_gen_decode.py +pytest -s tests/test_trtllm_gen_attention.py +pytest -s tests/test_trtllm_gen_mla.py # cudnn pytest -s tests/test_cudnn_decode.py diff --git a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py new file mode 100644 index 000000000..502bec389 --- /dev/null +++ b/tests/test_trtllm_gen_attention.py @@ -0,0 +1,535 @@ +import math + +import pytest +import torch +from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant + +import flashinfer +from flashinfer.utils import FP4Tensor + +DTYPE_MAP = { + "half": torch.float16, + "bf16": torch.bfloat16, + "fp8": torch.float8_e4m3fn, + "nvfp4": "nvfp4", +} + +GPU_DEVICE = "cuda:0" + +global_workspace_buffer = None + + +def flip_coin(*args, **kwargs): + # Use any test parameters to deterministically decide branch + # This makes test configurations go through different paths + param_tuple = args + tuple(sorted(kwargs.items())) + hash_value = hash(param_tuple) + return (hash_value % 2) == 0 + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +def generate_seq_lens(batch_size, max_q_len, max_in_kv_len): + q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32) + q_lens[-1] = max_q_len + in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int) + in_kv_lens[-1] = max_in_kv_len + seq_lens = q_lens + in_kv_lens + return q_lens, in_kv_lens, seq_lens + + +def generate_cumsum_lens(lens): + return torch.cat( + [ + torch.tensor([0], dtype=torch.int32, device=GPU_DEVICE), + torch.cumsum(lens.to(GPU_DEVICE), dim=0, dtype=torch.int32), + ] + ) + + +def create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype): + q = torch.randn( + torch.sum(q_lens).item(), + num_qo_heads, + head_dim, + dtype=torch.bfloat16 if q_dtype == "fp8" else DTYPE_MAP[q_dtype], + device=GPU_DEVICE, + ) + if q_dtype == "fp8": + q, q_scale = to_float8(q) + # Reference implementation have functional issue or low precision with fp8, use bfloat16 and fake-quantization instead. 
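+        # Fake-quantization: dequantizing the fp8 query back to bf16 keeps the
+        # same rounding error in the reference path that the kernel sees in its fp8 input.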
+ ref_q = q.bfloat16() * q_scale + else: + q_scale = 1.0 + ref_q = q + + return q, q_scale, ref_q + + +def create_kv_cache( + batch_size, seq_lens, page_size, num_kv_heads, head_dim, kv_dtype, ref_kv_dtype +): + # Create separate K and V caches + max_seq_len = torch.max(seq_lens).item() + num_tokens = max_seq_len * batch_size + num_pages = (num_tokens + page_size - 1) // page_size + ref_kv_dtype_torch = DTYPE_MAP[ref_kv_dtype] + if kv_dtype != "fp8": # for fp8, create with high precision to generate scale. + assert kv_dtype == ref_kv_dtype, ( + "kv_dtype and ref_kv_dtype must be the same for non-fp8 kv_cache" + ) + + k_cache = torch.randn( + num_pages, + num_kv_heads, + page_size, + head_dim, + dtype=ref_kv_dtype_torch, + device=GPU_DEVICE, + ) + v_cache = torch.randn( + num_pages, + num_kv_heads, + page_size, + head_dim, + dtype=ref_kv_dtype_torch, + device=GPU_DEVICE, + ) + + # Convert K and V separately to fp8 if needed + if kv_dtype == "fp8": + k_cache, k_scale = to_float8(k_cache) + v_cache, v_scale = to_float8(v_cache) + # use high precision and fake-quantization for reference to avoid precision/functional issue + ref_kv_cache = torch.stack( + [ + k_cache.to(ref_kv_dtype_torch) * k_scale, + v_cache.to(ref_kv_dtype_torch) * v_scale, + ], + dim=1, + ) + else: + k_scale = v_scale = 1.0 + ref_kv_cache = torch.stack([k_cache, v_cache], dim=1) + # Combine K and V into interleaved format for the API + kv_cache = torch.stack([k_cache, v_cache], dim=1) + + return kv_cache, k_scale, v_scale, ref_kv_cache + + +def create_page_table(batch_size, seq_lens, page_size): + page_per_seq = (seq_lens + page_size - 1) // page_size + max_num_pages_per_seq = torch.max(page_per_seq).item() + + # Generate random but unique page IDs for all sequences + total_pages_needed = torch.sum(page_per_seq).item() + all_page_ids = torch.randperm( + total_pages_needed, dtype=torch.int32, device=GPU_DEVICE + ) + + # Generate unique page IDs for all sequences + page_tables = torch.zeros( + (batch_size, max_num_pages_per_seq), dtype=torch.int32, device=GPU_DEVICE + ) + + # Populate page tables and track page assignments + page_id = 0 + for i in range(batch_size): + num_pages_needed = page_per_seq[i] + page_tables[i, :num_pages_needed] = all_page_ids[ + page_id : page_id + num_pages_needed + ] + page_id += num_pages_needed + return page_tables, all_page_ids, page_per_seq + + +def create_output(q, o_dtype, create_out_tensor): + if o_dtype == "fp8": + o_scale = torch.rand(1).item() * 0.5 + 0.5 # Scale range: 0.5 ~ 1.0 + else: + o_scale = 1.0 + o_sf_scale = ( + 300 if o_dtype == "nvfp4" else None + ) # choose a value to make error smaller by testing. 
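+    # nvfp4 output is block-quantized: one fp8 (e4m3) scale factor covers each 16-element vector.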
+ o_sf_vec_size = 16 if o_dtype == "nvfp4" else None + + if create_out_tensor: + if o_dtype == "nvfp4": + fp4_out_shape = q.shape[:-1] + (math.ceil(q.shape[-1] / 2),) + + fp4_out_scale_shape = ( + math.ceil(q.shape[0] / 128) * 128, + math.ceil(q.shape[1] * q.shape[2] / o_sf_vec_size / 4) * 4, + ) + + out_scale_factor = torch.empty( + fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=q.device + ) + extra_size = fp4_out_scale_shape[0] - q.shape[0] + o_sf_start_index = ( + torch.randint(0, extra_size, (1,)).item() if extra_size > 0 else 0 + ) + out_data = torch.empty(fp4_out_shape, dtype=torch.uint8, device=q.device) + out = FP4Tensor(out_data, out_scale_factor, o_sf_start_index) + else: + out = torch.empty_like(q, dtype=DTYPE_MAP[o_dtype]) + else: + out = None + return out, o_scale, o_sf_scale, o_sf_vec_size + + +def get_last_page_len(seq_lens, page_size): + kv_last_page_len = seq_lens % page_size + kv_last_page_len[kv_last_page_len == 0] = page_size + return kv_last_page_len + + +def unpack_compare_nvfp4( + output: FP4Tensor, + output_ref, + o_sf_scale, + o_sf_vec_size, + sf_rtol=2e-1, + sf_atol=2e-1, + rmse_tol=0.3, +): + output_ref, out_scale_factor_ref = ref_fp4_quant( + output_ref, o_sf_scale, o_sf_vec_size + ) + + output_unpacked = cast_from_fp4(output.data) + out_scale_factor = recover_swizzled_scales( + output.scale, + output_unpacked.shape[0], + math.prod(list(output_unpacked.shape[1:])), + o_sf_vec_size, + output.scale_start_index, + ) + + torch.testing.assert_close( + out_scale_factor.float().reshape(out_scale_factor_ref.shape), + out_scale_factor_ref.float(), + rtol=sf_rtol, + atol=sf_atol, + ) + rmse = torch.sqrt(torch.mean((output_unpacked.float() - output_ref.float()) ** 2)) + assert rmse.item() < rmse_tol + return output_unpacked, output_ref + + +@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize("batch_size", [4, 128, 256]) +@pytest.mark.parametrize("page_size", [16, 32, 64]) +@pytest.mark.parametrize("num_kv_heads", [2, 4]) +@pytest.mark.parametrize("head_grp_size", [1, 5, 8]) +@pytest.mark.parametrize("window_left", [-1]) # todo(Siyuan): add 127 window_left +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("half", "half", "half"), + ("bf16", "bf16", "bf16"), + ("fp8", "fp8", "fp8"), + ("fp8", "fp8", "nvfp4"), + ], +) +def test_trtllm_batch_prefill( + kv_layout, + batch_size, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, +): + # Set up test parameters + torch.manual_seed(0) + head_dim = 128 + MAX_Q_LEN = 511 + MAX_IN_KV_LEN = 2047 + + # Generate random sequence lengths + num_qo_heads = num_kv_heads * head_grp_size + q_lens, in_kv_lens, seq_lens = generate_seq_lens( + batch_size, MAX_Q_LEN, MAX_IN_KV_LEN + ) + + # Create query tensor and related data + q, q_scale, ref_q = create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype) + q_indptr = generate_cumsum_lens(q_lens) + + # Create KV cache and related data + kv_cache, k_scale, v_scale, ref_kv_cache = create_kv_cache( + batch_size, + seq_lens, + page_size, + num_kv_heads, + head_dim, + kv_dtype, + "bf16" if q_dtype == "fp8" else q_dtype, + ) + page_table, all_page_ids, page_per_seq = create_page_table( + batch_size, seq_lens, page_size + ) + kv_indptr = generate_cumsum_lens(page_per_seq) + kv_last_page_len = get_last_page_len(seq_lens, page_size) + + # Create output tensor and related data + create_out_tensor = flip_coin( + batch_size, page_size, num_kv_heads, head_grp_size, o_dtype + ) + out, o_scale, 
o_sf_scale, o_sf_vec_size = create_output( + q, o_dtype, create_out_tensor + ) + + global global_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.zeros( + 128 * 1024 * 1024, dtype=torch.int8, device=GPU_DEVICE + ) + workspace_buffer = global_workspace_buffer + + # Run reference wrapper + wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + plan_params = { + "qo_indptr": q_indptr, + "paged_kv_indptr": kv_indptr, + "paged_kv_indices": all_page_ids, + "paged_kv_last_page_len": kv_last_page_len.to(GPU_DEVICE), + "num_qo_heads": num_qo_heads, + "num_kv_heads": num_kv_heads, + "head_dim_qk": head_dim, + "page_size": page_size, + "causal": True, + "pos_encoding_mode": "NONE", + "logits_soft_cap": 0.0, + "q_data_type": ref_q.dtype, + "kv_data_type": ref_kv_cache.dtype, + "window_left": window_left, + } + wrapper_ref.plan(**plan_params) + output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + + # Run trtllm-gen function call + sm_scale = float(1.0 / (head_dim**0.5)) + output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( + q.contiguous(), + kv_cache, + workspace_buffer, + page_table, + seq_lens.to(GPU_DEVICE), + torch.max(q_lens).item(), + torch.max(seq_lens).item(), + q_scale * k_scale * sm_scale, # bmm1_scale + v_scale / o_scale, # bmm2_scale + batch_size, + q_indptr, + kv_indptr, + window_left, # window_left + out=out, + out_dtype=DTYPE_MAP[o_dtype], + o_sf_scale=o_sf_scale, + o_sf_vec_size=o_sf_vec_size, + ) + + if o_dtype == "nvfp4": + output, output_ref = unpack_compare_nvfp4( + output, output_ref, o_sf_scale, o_sf_vec_size + ) + assert o_scale == 1.0 + rtol, atol = 4e-1, 1e0 + elif o_dtype == "fp8": + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 1e-2 + + # convert to float32 for fp8 is not supported by assert_close + torch.testing.assert_close( + output.float() * o_scale, output_ref.float(), rtol=rtol, atol=atol + ) + + if o_dtype != "nvfp4": # wrapper api does not support fp4 output yet. + # test wrapper with trtllm-gen backend + wrapper_trtllm_gen = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, backend="trtllm-gen" + ) + plan_params["q_data_type"] = q.dtype + plan_params["kv_data_type"] = kv_cache.dtype + wrapper_trtllm_gen.plan(**plan_params) + output_wrapper = wrapper_trtllm_gen.run( + q.contiguous(), + kv_cache, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale / o_scale, + ) + # v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel. 
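+    # A bit-exact match is only expected when no extra output scaling is applied;
+    # otherwise the deferred v_scale / o_scale multiply introduces small differences.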
+ if v_scale == o_scale == 1.0: + assert (output_wrapper == output).all() + else: + torch.testing.assert_close( + output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1 + ) + + +@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize("batch_size", [4, 128, 256]) +@pytest.mark.parametrize("page_size", [16, 32, 64]) +@pytest.mark.parametrize("num_kv_heads", [2, 4]) +@pytest.mark.parametrize("head_grp_size", [1, 5, 8]) +@pytest.mark.parametrize("window_left", [-1, 127]) +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("half", "half", "half"), + ("half", "fp8", "half"), + ("bf16", "bf16", "bf16"), + ("bf16", "fp8", "bf16"), + ("fp8", "fp8", "fp8"), + ("fp8", "fp8", "nvfp4"), + ], +) +def test_trtllm_batch_decode( + kv_layout, + batch_size, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, +): + # Set up test parameters + torch.manual_seed(0) + head_dim = 128 + MAX_Q_LEN = 1 # must be 1 for decode test + MAX_IN_KV_LEN = 110 + + # Generate random sequence lengths + num_qo_heads = num_kv_heads * head_grp_size + q_lens, in_kv_lens, seq_lens = generate_seq_lens( + batch_size, MAX_Q_LEN, MAX_IN_KV_LEN + ) + + # Create query tensor and related data + q, q_scale, ref_q = create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype) + + # Create KV cache and related data + kv_cache, k_scale, v_scale, ref_kv_cache = create_kv_cache( + batch_size, + seq_lens, + page_size, + num_kv_heads, + head_dim, + kv_dtype, + "bf16" if q_dtype == "fp8" else q_dtype, + ) + page_table, all_page_ids, page_per_seq = create_page_table( + batch_size, seq_lens, page_size + ) + kv_indptr = generate_cumsum_lens(page_per_seq) + kv_last_page_len = get_last_page_len(seq_lens, page_size) + + # Create output tensor and related data + create_out_tensor = flip_coin( + batch_size, page_size, num_kv_heads, head_grp_size, o_dtype + ) + out, o_scale, o_sf_scale, o_sf_vec_size = create_output( + q, o_dtype, create_out_tensor + ) + + global global_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.zeros( + 128 * 1024 * 1024, dtype=torch.int8, device=GPU_DEVICE + ) + workspace_buffer = global_workspace_buffer + + # Run reference wrapper + wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, use_tensor_cores=True + ) + plan_params = { + "indptr": kv_indptr, + "indices": all_page_ids, + "last_page_len": kv_last_page_len.to(GPU_DEVICE), + "num_qo_heads": num_qo_heads, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + "page_size": page_size, + "pos_encoding_mode": "NONE", + "kv_data_type": ref_kv_cache.dtype, + "q_data_type": ref_q.dtype, + "window_left": window_left, + } + wrapper_ref.plan(**plan_params) + output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + + # Run trtllm-gen function call + sm_scale = float(1.0 / (head_dim**0.5)) + + output = flashinfer.decode.trtllm_batch_decode_with_kv_cache( + q.contiguous(), + kv_cache, + workspace_buffer, + page_table, + seq_lens.to(GPU_DEVICE), + torch.max(seq_lens).item(), + q_scale * k_scale * sm_scale, # bmm1_scale + v_scale / o_scale, # bmm2_scale + window_left, # window_left + out=out, + out_dtype=DTYPE_MAP[o_dtype], + o_sf_scale=o_sf_scale, + o_sf_vec_size=o_sf_vec_size, + ) + + if o_dtype == "nvfp4": + output, output_ref = unpack_compare_nvfp4( + output, output_ref, o_sf_scale, o_sf_vec_size + ) + assert o_scale == 1.0 + rtol, atol = 3e-1, 1e0 + elif o_dtype == "fp8": + rtol, atol = 5e-2, 7e-2 + 
else: + rtol, atol = 1e-2, 1e-2 + + # convert to float32 for fp8 is not supported by assert_close + torch.testing.assert_close( + output.float() * o_scale, output_ref.float(), rtol=rtol, atol=atol + ) + + if o_dtype != "nvfp4": # wrapper api does not support fp4 output yet. + # test wrapper with trtllm-gen backend + wrapper_trtllm_gen = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, backend="trtllm-gen" + ) + plan_params["q_data_type"] = q.dtype + plan_params["kv_data_type"] = kv_cache.dtype + wrapper_trtllm_gen.plan(**plan_params) + output_wrapper = wrapper_trtllm_gen.run( + q.contiguous(), + kv_cache, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale / o_scale, + ) + # v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel. + if v_scale == o_scale == 1.0: + assert (output_wrapper == output).all() + else: + torch.testing.assert_close( + output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1 + ) diff --git a/tests/test_trtllm_gen_context.py b/tests/test_trtllm_gen_context.py deleted file mode 100644 index 7c32a093a..000000000 --- a/tests/test_trtllm_gen_context.py +++ /dev/null @@ -1,511 +0,0 @@ -import math - -import pytest -import torch -from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant - -import flashinfer -from flashinfer.utils import FP4Tensor - -global_workspace_buffer = None - - -def flip_coin(*args, **kwargs): - # Use any test parameters to deterministically decide branch - # This makes test configurations go through different paths - param_tuple = args + tuple(sorted(kwargs.items())) - hash_value = hash(param_tuple) - return (hash_value % 2) == 0 - - -def to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) - scale = finfo.max / amax * 0.1 - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype), scale.float().reciprocal() - - -@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND -@pytest.mark.parametrize("batch_size", [4, 8, 128]) -@pytest.mark.parametrize("kv_len", [512, 2048]) -@pytest.mark.parametrize("qo_len", [32, 16, 128, 512]) -@pytest.mark.parametrize("num_qo_heads", [4, 32]) -@pytest.mark.parametrize("head_dim", [128]) -@pytest.mark.parametrize("page_size", [16, 32, 64]) -@pytest.mark.parametrize("num_kv_heads", [4]) -@pytest.mark.parametrize("q_dtype", ["half", "bf16", "fp8"]) -@pytest.mark.parametrize("logits_soft_cap", [0.0]) -@pytest.mark.parametrize("window_left", [-1]) # todo(Siyuan): add 127 window_left -def test_trtllm_batch_context_wrapper( - kv_layout, - batch_size, - qo_len, - kv_len, - num_qo_heads, - head_dim, - page_size, - num_kv_heads, - q_dtype, - logits_soft_cap, - window_left, -): - seed = 0 - torch.manual_seed(seed) - device = "cuda:0" - - dtype_map = { - "half": torch.float16, - "bf16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, - "nvfp4": "nvfp4", - } - - if q_dtype == "fp8": - q = torch.randn( - batch_size * qo_len, - num_qo_heads, - head_dim, - device=device, - dtype=torch.bfloat16, - ) - q, q_scale = to_float8(q) - # Reference implementation have functional issue or low precision with fp8, use bfloat16 and fake-quantization instead. 
- ref_q = q.bfloat16() * q_scale - else: - q = torch.randn( - batch_size * qo_len, - num_qo_heads, - head_dim, - device=device, - dtype=dtype_map[q_dtype], - ) - q_scale = 1.0 - ref_q = q - - q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len - num_pages_per_seq = (kv_len + page_size - 1) // page_size - total_num_pages = num_pages_per_seq * batch_size - if kv_layout == "HND": - kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim] - else: - kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] - kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") - kv_data = kv_data_fp32.to(dtype_map[q_dtype]) - ref_kv_data = kv_data.bfloat16() if q_dtype == "fp8" else kv_data - kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq - kv_indices_cpu = torch.arange(0, total_num_pages).int() - kv_last_page_len_cpu = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 - ) - global global_workspace_buffer - if global_workspace_buffer is None: - global_workspace_buffer = torch.zeros( - 256 * 1024 * 1024, dtype=torch.int8, device="cuda:0" - ) - workspace_buffer = global_workspace_buffer - - # reference - q_indptr_gpu = q_indptr_cpu.to(device) - kv_indptr_gpu = kv_indptr_cpu.to(device) - kv_indices_gpu = kv_indices_cpu.to(device) - kv_last_page_len_gpu = kv_last_page_len_cpu.to(device) - wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout - ) - wrapper.plan( - q_indptr_gpu, - kv_indptr_gpu, - kv_indices_gpu, - kv_last_page_len_gpu, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - causal=True, - pos_encoding_mode="NONE", - logits_soft_cap=logits_soft_cap, - q_data_type=ref_q.dtype, - window_left=window_left, - ) - reference_output = wrapper.run(ref_q, ref_kv_data) - reference_kv_cache = kv_data.clone() - - # trtllm-gen - wrapper2 = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, backend="trtllm-gen" - ) - wrapper2.plan( - q_indptr_gpu, - kv_indptr_gpu, - kv_indices_gpu, - kv_last_page_len_gpu, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - causal=True, - pos_encoding_mode="NONE", - logits_soft_cap=logits_soft_cap, - q_data_type=dtype_map[q_dtype], - window_left=window_left, - ) - output = wrapper2.run(q, kv_data, q_scale=q_scale) - rmse = torch.sqrt(torch.mean((output.float() - reference_output.float()) ** 2)) - assert rmse.item() < (1e-2 if q_dtype == "fp8" else 1e-3) - - if q_dtype == "fp8": - rtol, atol = 8e-2, 8e-2 - else: - rtol, atol = 1e-2, 1e-2 - - torch.testing.assert_close( - output.float(), reference_output.float(), rtol=rtol, atol=atol - ) - torch.testing.assert_close( - reference_kv_cache.float(), kv_data.float(), rtol=rtol, atol=atol - ) - - # Test trtllm_batch_context_with_kv_cache function - seq_lens = flashinfer.page.get_seq_lens( - kv_indptr_cpu, kv_last_page_len_cpu, page_size - ).to(device) - - # Build block_tables using existing kv_indices_gpu - blocks_per_seq = [ - (seq_len + page_size - 1) // page_size for seq_len in seq_lens.cpu() - ] - max_num_blocks_per_seq = max(blocks_per_seq) - block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), - dtype=torch.int32, - device=device, - ) - block_id = kv_indptr_cpu[0] - for i in range(batch_size): - num_blocks_needed = blocks_per_seq[i] - block_tables[i, :num_blocks_needed] = kv_indices_gpu[ - block_id : block_id + num_blocks_needed - ] - block_id += num_blocks_needed - - # Call trtllm_batch_context_with_kv_cache - direct_output = 
flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=q, - kv_cache=kv_data, - workspace_buffer=workspace_buffer, - block_tables=block_tables, - seq_lens=seq_lens, - max_q_len=qo_len, - max_kv_len=kv_len, - bmm1_scale=q_scale / math.sqrt(head_dim), - bmm2_scale=1, - batch_size=batch_size, - cum_seq_lens_q=q_indptr_gpu, - cum_seq_lens_kv=kv_indptr_gpu, - window_left=window_left, - ) - - # Compare direct function output with wrapper output - assert (direct_output == output).all() - - -@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND -@pytest.mark.parametrize("batch_size", [4, 128, 256]) -@pytest.mark.parametrize("page_size", [16, 32, 64]) -@pytest.mark.parametrize("num_kv_heads", [2, 4]) -@pytest.mark.parametrize("head_grp_size", [1, 5, 8]) -@pytest.mark.parametrize("window_left", [-1]) # todo(Siyuan): add 127 window_left -@pytest.mark.parametrize( - "q_dtype,kv_cache_dtype,o_dtype", - [ - ("half", "half", "half"), - ("bf16", "bf16", "bf16"), - ("fp8", "fp8", "fp8"), - ("fp8", "fp8", "nvfp4"), - ], -) -def test_trtllm_batch_prefill( - kv_layout, - batch_size, - page_size, - num_kv_heads, - head_grp_size, - window_left, - q_dtype, - o_dtype, - kv_cache_dtype, -): - # Set up test parameters - seed = 0 - torch.manual_seed(seed) - device = "cuda:0" - head_dim = 128 - MAX_Q_LEN = 512 - MAX_IN_KV_LEN = 2048 - - dtype_map = { - "half": torch.float16, - "bf16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, - "nvfp4": "nvfp4", - } - - # Sequence lengths and block tables - q_lens = torch.randint(1, MAX_Q_LEN, (batch_size,), dtype=torch.int32) - q_lens[-1] = MAX_Q_LEN - max_q_len = torch.max(q_lens).item() - q_lens_tensor = q_lens.to(device) - num_qo_heads = num_kv_heads * head_grp_size - - q = torch.randn( - torch.sum(q_lens).item(), - num_qo_heads, - head_dim, - dtype=torch.bfloat16 if q_dtype == "fp8" else dtype_map[q_dtype], - device=device, - ) - if q_dtype == "fp8": - q, q_scale = to_float8(q) - # Reference implementation have functional issue or low precision with fp8, use bfloat16 and fake-quantization instead. 
- ref_q = q.bfloat16() * q_scale - else: - q_scale = 1.0 - ref_q = q - - in_kv_lens = torch.randint(0, MAX_IN_KV_LEN, (batch_size,), dtype=torch.int) - in_kv_lens[-1] = MAX_IN_KV_LEN - seq_lens = in_kv_lens + q_lens - seq_lens_gpu = seq_lens.to(device) - max_seq_len = torch.max(seq_lens).item() - - blocks_per_seq = (seq_lens + page_size - 1) // page_size - max_num_blocks_per_seq = torch.max(blocks_per_seq).item() - - # Generate random but unique block IDs for all sequences - total_blocks_needed = torch.sum(blocks_per_seq).item() - all_block_ids = torch.randperm( - total_blocks_needed, dtype=torch.int32, device=device - ) # Random permutation - - # Generate unique block IDs for all sequences - block_id = 0 - block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), dtype=torch.int32, device=device - ) - - # Populate block tables and track block assignments - block_id = 0 - for i in range(batch_size): - num_blocks_needed = blocks_per_seq[i] - block_tables[i, :num_blocks_needed] = all_block_ids[ - block_id : block_id + num_blocks_needed - ] - block_id += num_blocks_needed - - # Create separate K and V caches - num_tokens = max_seq_len * batch_size - num_blocks = (num_tokens + page_size - 1) // page_size - - kv_dtype = dtype_map[q_dtype] if q_dtype != "fp8" else torch.bfloat16 - k_cache = torch.randn( - num_blocks, num_kv_heads, page_size, head_dim, dtype=kv_dtype, device=device - ) - v_cache = torch.randn( - num_blocks, num_kv_heads, page_size, head_dim, dtype=kv_dtype, device=device - ) - # Convert K and V separately to fp8 if needed - if kv_cache_dtype.startswith("fp8"): - k_cache, k_scale = to_float8(k_cache) - v_cache, v_scale = to_float8(v_cache) - # use high precision for reference kv_cache to avoid precision/functional issue - ref_kv_type = torch.bfloat16 if q_dtype == "fp8" else dtype_map[q_dtype] - ref_kv_cache = torch.stack( - [k_cache.to(ref_kv_type) * k_scale, v_cache.to(ref_kv_type) * v_scale], - dim=1, - ) - else: - k_scale = v_scale = 1.0 - ref_kv_cache = torch.stack([k_cache, v_cache], dim=1) - - # Combine K and V into interleaved format for the API - kv_cache = torch.stack( - [k_cache, v_cache], dim=1 - ) # Shape: (num_blocks, 2, num_kv_heads, page_size, head_dim) - - if o_dtype == "fp8": - o_scale = torch.rand(1).item() * 0.5 + 0.5 # Scale range: 0.5 ~ 1.0 - else: - o_scale = 1.0 - o_sf_scale = ( - 300 if o_dtype == "nvfp4" else None - ) # choose a value to make error smaller by testing. 
- o_sf_vec_size = 16 if o_dtype == "nvfp4" else None - sm_scale = float(1.0 / (head_dim**0.5)) - - global global_workspace_buffer - if global_workspace_buffer is None: - global_workspace_buffer = torch.zeros( - 128 * 1024 * 1024, dtype=torch.int8, device="cuda:0" - ) - workspace_buffer = global_workspace_buffer - - q_indptr = torch.cat( - [ - torch.tensor([0], dtype=torch.int32, device=device), - torch.cumsum(q_lens_tensor, dim=0, dtype=torch.int32), - ] - ) - kv_indptr = torch.cat( - [ - torch.tensor([0], dtype=torch.int32, device=device), - torch.cumsum(blocks_per_seq.to(device), dim=0, dtype=torch.int32), - ] - ) - - if flip_coin(batch_size, page_size, num_kv_heads, head_grp_size, o_dtype): - if o_dtype == "nvfp4": - fp4_out_shape = q.shape[:-1] + (math.ceil(q.shape[-1] / 2),) - - fp4_out_scale_shape = ( - math.ceil(q.shape[0] / 128) * 128, - math.ceil(q.shape[1] * q.shape[2] / o_sf_vec_size / 4) * 4, - ) - - out_scale_factor = torch.empty( - fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=q.device - ) - extra_size = fp4_out_scale_shape[0] - q.shape[0] - o_sf_start_index = ( - torch.randint(0, extra_size, (1,)).item() if extra_size > 0 else 0 - ) - out_data = torch.empty(fp4_out_shape, dtype=torch.uint8, device=q.device) - out = FP4Tensor(out_data, out_scale_factor, o_sf_start_index) - else: - out = torch.empty_like(q, dtype=dtype_map[o_dtype]) - else: - out = None - - output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( - q.contiguous(), - kv_cache, - workspace_buffer, - block_tables, - seq_lens_gpu, - max_q_len, - max_seq_len, - q_scale * k_scale * sm_scale, # bmm1_scale - v_scale / o_scale, # bmm2_scale - batch_size, - q_indptr, - kv_indptr, - window_left, # window_left - out=out, - out_dtype=dtype_map[o_dtype], - o_sf_scale=o_sf_scale, - o_sf_vec_size=o_sf_vec_size, - ) - - # Handle different return types based on out_dtype - if o_dtype == "nvfp4": - out_scale_factor = output.scale # FP4Tensor.scale - o_sf_start_index = output.scale_start_index - output = output.data # FP4Tensor.data - else: - out_scale_factor = None - - wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout - ) - - # Calculate last page lengths - kv_last_page_len = seq_lens_gpu % page_size - kv_last_page_len[kv_last_page_len == 0] = page_size - logits_soft_cap = 0.0 - - wrapper.plan( - q_indptr, - kv_indptr, - all_block_ids, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - causal=True, - pos_encoding_mode="NONE", - logits_soft_cap=logits_soft_cap, - q_data_type=ref_q.dtype, - window_left=window_left, - ) - output_ref = wrapper.run(ref_q, ref_kv_cache) - - if q_dtype == "fp8" and o_dtype == "nvfp4": - rtol, atol = 4e-1, 1e0 - elif q_dtype == "fp8" and o_dtype == "fp8": - rtol, atol = 5e-2, 7e-2 - else: - rtol, atol = 1e-2, 1e-2 - - if o_dtype == "nvfp4": - output = cast_from_fp4(output) - output_ref, out_scale_factor_ref = ref_fp4_quant(output_ref, o_sf_scale, 16) - out_scale_factor = recover_swizzled_scales( - out_scale_factor, - output.shape[0], - output.shape[1] * output.shape[2], - 16, - o_sf_start_index, - ) - - torch.testing.assert_close( - out_scale_factor.float().reshape(out_scale_factor_ref.shape), - out_scale_factor_ref.float(), - rtol=2e-1, - atol=2e-1, - ) - rmse = torch.sqrt( - torch.mean((output.float() * o_scale - output_ref.float()) ** 2) - ) - assert rmse.item() < 0.3 - - # convert to float32 for fp8 is not supported by assert_close - torch.testing.assert_close( - output.float() * o_scale, output_ref.float(), 
rtol=rtol, atol=atol - ) - - if o_dtype != "nvfp4": # wrapper api does not support fp4 output yet. - # test wrapper with trtllm-gen backend - wrapper2 = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, backend="trtllm-gen" - ) - wrapper2.plan( - q_indptr, - kv_indptr, - all_block_ids, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - causal=True, - pos_encoding_mode="NONE", - logits_soft_cap=logits_soft_cap, - q_data_type=q.dtype, - window_left=window_left, - ) - output2 = wrapper2.run( - q.contiguous(), - kv_cache, - q_scale=q_scale, - k_scale=k_scale, - v_scale=v_scale / o_scale, - ) - # v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel. - if v_scale == o_scale == 1.0: - assert (output2 == output).all() - else: - torch.testing.assert_close( - output.float(), output2.float(), rtol=1e-1, atol=1e-1 - ) diff --git a/tests/test_trtllm_gen_decode.py b/tests/test_trtllm_gen_decode.py deleted file mode 100644 index 3d002ffce..000000000 --- a/tests/test_trtllm_gen_decode.py +++ /dev/null @@ -1,597 +0,0 @@ -import math - -import pytest -import torch -import torch.nn.functional as F -from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant - -import flashinfer -from flashinfer.utils import FP4Tensor - -global_workspace_buffer = None - - -def flip_coin(*args, **kwargs): - # Use any test parameters to deterministically decide branch - # This makes test configurations go through different paths - param_tuple = args + tuple(sorted(kwargs.items())) - hash_value = hash(param_tuple) - return (hash_value % 2) == 0 - - -def to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) - scale = finfo.max / amax * 0.1 - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype), scale.float().reciprocal() - - -def scaled_dot_product(q, k, v, mask=None): - d_k = q.size()[-1] - attn_logits = torch.matmul(q, k.transpose(-2, -1)) - attn_logits = attn_logits / math.sqrt(d_k) - if mask is not None: - attn_logits = attn_logits.masked_fill(mask == 0, -9e15) - attention = F.softmax(attn_logits, dim=-1) - values = torch.matmul(attention, v) - return values - - -def reference_paged_attention( - q: torch.Tensor, # [batch_size, num_q_heads, head_dim] - kv_cache: torch.Tensor, # [num_blocks, 2, num_kv_heads, page_size, head_dim] - block_tables: torch.Tensor, # [batch_size, max_blocks_per_seq] - seq_lens: torch.Tensor, # [batch_size] - page_size: int, - scale: float, - num_kv_heads: int, - head_dim: int, -): - batch_size, num_q_heads, _ = q.shape - device = q.device - dtype = q.dtype - head_grp_size = num_q_heads // num_kv_heads - - # Initialize output tensor - output = torch.zeros_like(q) - - for b in range(batch_size): - seq_len = seq_lens[b].item() - num_blocks = (seq_len + page_size - 1) // page_size - - # Get the blocks for this sequence - blocks = block_tables[b, :num_blocks] - - # Initialize K and V for this sequence - k_seq = torch.zeros( - (num_kv_heads, seq_len, head_dim), device=device, dtype=dtype - ) - v_seq = torch.zeros( - (num_kv_heads, seq_len, head_dim), device=device, dtype=dtype - ) - - # Gather K and V from kv_cache - current_pos = 0 - for block_id in blocks: - # Calculate how many tokens we can take from this block - remaining_tokens = seq_len - current_pos - tokens_to_take = min(page_size, remaining_tokens) - - if 
tokens_to_take <= 0: - break - - # Get K and V from the block - k_block = kv_cache[ - block_id, 0, :, :tokens_to_take, : - ] # [num_kv_heads, tokens_to_take, head_dim] - v_block = kv_cache[ - block_id, 1, :, :tokens_to_take, : - ] # [num_kv_heads, tokens_to_take, head_dim] - - # Store in the sequence tensor - k_seq[:, current_pos : current_pos + tokens_to_take, :] = k_block - v_seq[:, current_pos : current_pos + tokens_to_take, :] = v_block - - current_pos += tokens_to_take - - q_b = q[b].unsqueeze(1) - - k_seq = torch.repeat_interleave(k_seq, head_grp_size, dim=0) - v_seq = torch.repeat_interleave(v_seq, head_grp_size, dim=0) - output[b] = scaled_dot_product( - q_b.unsqueeze(0), k_seq.unsqueeze(0), v_seq.unsqueeze(0) - ).squeeze() - - return output - - -@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND -@pytest.mark.parametrize("batch_size", [4, 128, 256]) -@pytest.mark.parametrize("page_size", [16, 32, 64]) -@pytest.mark.parametrize("num_kv_heads", [2, 4]) -@pytest.mark.parametrize("head_grp_size", [1, 5, 8]) -@pytest.mark.parametrize("window_left", [-1, 127]) -@pytest.mark.parametrize( - "q_dtype,kv_cache_dtype,o_dtype", - [ - ("half", "half", "half"), - ("half", "fp8", "half"), - ("bf16", "bf16", "bf16"), - ("bf16", "fp8", "bf16"), - ("fp8", "fp8", "fp8"), - ("fp8", "fp8", "nvfp4"), - ], -) -def test_trtllm_batch_decode_fmha( - kv_layout, - batch_size, - page_size, - num_kv_heads, - head_grp_size, - window_left, - q_dtype, - o_dtype, - kv_cache_dtype, -): - # Set up test parameters - seed = 0 - torch.manual_seed(seed) - device = "cuda:0" - head_dim = 128 - MAX_SEQ_LEN = 110 - - dtype_map = { - "half": torch.float16, - "bf16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, - "nvfp4": "nvfp4", - } - - # Sequence lengths and block tables - num_qo_heads = num_kv_heads * head_grp_size - - q = torch.randn( - batch_size, - num_qo_heads, - head_dim, - dtype=torch.bfloat16 if q_dtype == "fp8" else dtype_map[q_dtype], - device=device, - ) - if q_dtype == "fp8": - q, q_scale = to_float8(q) - # Reference implementation have functional issue or low precision with fp8, use bfloat16 and fake-quantization instead. 
- ref_q = q.bfloat16() * q_scale - else: - q_scale = 1.0 - ref_q = q - - seq_lens = torch.randint(1, MAX_SEQ_LEN, (batch_size,), dtype=torch.int32) - seq_lens[-1] = MAX_SEQ_LEN - seq_lens_gpu = seq_lens.to(device) - max_seq_len = torch.max(seq_lens).item() - - blocks_per_seq = (seq_lens + page_size - 1) // page_size - max_num_blocks_per_seq = torch.max(blocks_per_seq).item() - - # Generate random but unique block IDs for all sequences - total_blocks_needed = torch.sum(blocks_per_seq).item() - all_block_ids = torch.randperm( - total_blocks_needed, dtype=torch.int32, device=device - ) # Random permutation - - # Generate unique block IDs for all sequences - block_id = 0 - block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), dtype=torch.int32, device=device - ) - - # Populate block tables and track block assignments - block_id = 0 - for i in range(batch_size): - num_blocks_needed = blocks_per_seq[i] - block_tables[i, :num_blocks_needed] = all_block_ids[ - block_id : block_id + num_blocks_needed - ] - block_id += num_blocks_needed - - # Create separate K and V caches - num_tokens = max_seq_len * batch_size - num_blocks = (num_tokens + page_size - 1) // page_size - - kv_dtype = dtype_map[q_dtype] if q_dtype != "fp8" else torch.bfloat16 - k_cache = torch.randn( - num_blocks, num_kv_heads, page_size, head_dim, dtype=kv_dtype, device=device - ) - v_cache = torch.randn( - num_blocks, num_kv_heads, page_size, head_dim, dtype=kv_dtype, device=device - ) - # Convert K and V separately to fp8 if needed - if kv_cache_dtype.startswith("fp8"): - k_cache, k_scale = to_float8(k_cache) - v_cache, v_scale = to_float8(v_cache) - # use high precision for reference kv_cache to avoid precision/functional issue - ref_kv_type = torch.bfloat16 if q_dtype == "fp8" else dtype_map[q_dtype] - ref_kv_cache = torch.stack( - [k_cache.to(ref_kv_type) * k_scale, v_cache.to(ref_kv_type) * v_scale], - dim=1, - ) - else: - k_scale = v_scale = 1.0 - ref_kv_cache = torch.stack([k_cache, v_cache], dim=1) - - # Combine K and V into interleaved format for the API - kv_cache = torch.stack( - [k_cache, v_cache], dim=1 - ) # Shape: (num_blocks, 2, num_kv_heads, page_size, head_dim) - - if o_dtype == "fp8": - o_scale = torch.rand(1).item() * 0.5 + 0.5 # Scale range: 0.5 ~ 1.0 - else: - o_scale = 1.0 - o_sf_scale = ( - 300 if o_dtype == "nvfp4" else None - ) # choose a value to make error smaller by testing. 
- o_sf_vec_size = 16 if o_dtype == "nvfp4" else None - - sm_scale = float(1.0 / (head_dim**0.5)) - - global global_workspace_buffer - if global_workspace_buffer is None: - global_workspace_buffer = torch.zeros( - 128 * 1024 * 1024, dtype=torch.int8, device="cuda:0" - ) - workspace_buffer = global_workspace_buffer - - # Compute kv_indptr as cumulative sum of blocks per sequence - kv_indptr = torch.cat( - [ - torch.tensor([0], dtype=torch.int32, device=device), - torch.cumsum(blocks_per_seq.to(device), dim=0, dtype=torch.int32), - ] - ) - - if flip_coin(batch_size, page_size, num_kv_heads, head_grp_size, o_dtype): - if o_dtype == "nvfp4": - fp4_out_shape = q.shape[:-1] + (math.ceil(q.shape[-1] / 2),) - - fp4_out_scale_shape = ( - math.ceil(q.shape[0] / 128) * 128, - math.ceil(q.shape[1] * q.shape[2] / o_sf_vec_size / 4) * 4, - ) - - out_scale_factor = torch.empty( - fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=q.device - ) - extra_size = fp4_out_scale_shape[0] - q.shape[0] - o_sf_start_index = ( - torch.randint(0, extra_size, (1,)).item() if extra_size > 0 else 0 - ) - out_data = torch.empty(fp4_out_shape, dtype=torch.uint8, device=q.device) - out = FP4Tensor(out_data, out_scale_factor, o_sf_start_index) - else: - out = torch.empty_like(q, dtype=dtype_map[o_dtype]) - else: - out = None - - output = flashinfer.decode.trtllm_batch_decode_with_kv_cache( - q.contiguous(), - kv_cache, - workspace_buffer, - block_tables, - seq_lens_gpu, - max_seq_len, - q_scale * k_scale * sm_scale, # bmm1_scale - v_scale / o_scale, # bmm2_scale - window_left, # window_left - out=out, - out_dtype=dtype_map[o_dtype], - o_sf_scale=o_sf_scale, - o_sf_vec_size=o_sf_vec_size, - ) - - # Handle different return types based on out_dtype - if o_dtype == "nvfp4": - out_scale_factor = output.scale # FP4Tensor.scale - o_sf_start_index = output.scale_start_index - output = output.data # FP4Tensor.data - else: - out_scale_factor = None - - output = output.squeeze(1) - - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, use_tensor_cores=True - ) - blocks_per_seq = (seq_lens_gpu + page_size - 1) // page_size - - # Calculate last page lengths - kv_last_page_len = seq_lens_gpu % page_size - kv_last_page_len[kv_last_page_len == 0] = page_size - - wrapper.plan( - kv_indptr, - all_block_ids, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode="NONE", - window_left=window_left, - data_type=ref_kv_cache.dtype, - q_data_type=ref_q.dtype, - ) - - output_ref = wrapper.run(ref_q, ref_kv_cache) - - if q_dtype == "fp8" and o_dtype == "nvfp4": - rtol, atol = 3e-1, 1e0 - elif q_dtype == "fp8" and o_dtype == "fp8": - rtol, atol = 5e-2, 7e-2 - else: - rtol, atol = 1e-2, 5e-2 - - if o_dtype == "nvfp4": - output = cast_from_fp4(output) - output_ref, out_scale_factor_ref = ref_fp4_quant(output_ref, o_sf_scale, 16) - out_scale_factor = recover_swizzled_scales( - out_scale_factor, - output.shape[0], - output.shape[1] * output.shape[2], - 16, - o_sf_start_index, - ) - - torch.testing.assert_close( - out_scale_factor.float().reshape(out_scale_factor_ref.shape), - out_scale_factor_ref.float(), - rtol=2e-1, - atol=2e-1, - ) - rmse = torch.sqrt( - torch.mean((output.float() * o_scale - output_ref.float()) ** 2) - ) - assert rmse.item() < 0.3 - # convert to float32 for fp8 is not supported by assert_close - torch.testing.assert_close( - output.float() * o_scale, output_ref.float(), rtol=rtol, atol=atol - ) - - if o_dtype != "nvfp4": # wrapper api does not support 
fp4 output yet. - # test wrapper with trtllm-gen backend - wrapper2 = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, backend="trtllm-gen" - ) - wrapper2.plan( - kv_indptr, - all_block_ids, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode="NONE", - data_type=kv_cache.dtype, - q_data_type=q.dtype, - window_left=window_left, - ) - output2 = wrapper2.run( - q.contiguous(), - kv_cache, - q_scale=q_scale, - k_scale=k_scale, - v_scale=v_scale / o_scale, - ) - # v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel. - if v_scale == o_scale == 1.0: - assert (output2 == output).all() - else: - torch.testing.assert_close( - output.float(), output2.float(), rtol=1e-1, atol=1e-1 - ) - - -@pytest.mark.parametrize( - "batch_size", - [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024], -) -@pytest.mark.parametrize("scale", [1.0, 0.5]) -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) -@pytest.mark.parametrize("page_size", [32, 64]) -@pytest.mark.parametrize("q_len_per_request", [1, 2]) -@pytest.mark.parametrize("dynamic_scale", [False]) -def test_trtllm_batch_decode_mla( - batch_size: int, - scale: float, - dtype: torch.dtype, - page_size: int, - q_len_per_request: int, - dynamic_scale: bool, -): - if dynamic_scale and dtype != torch.float8_e4m3fn: - pytest.skip("Dynamic scale is not supported for non-fp8 dtype") - - torch.manual_seed(42) - device = "cuda:0" - - # Fixed max sequence length - MAX_SEQ_LEN = 1024 - - # Deepseek attention config (decode-MLA) - num_q_heads = 128 - qk_nope_head_dim = 128 - qk_rope_head_dim = 64 - kv_lora_rank = 512 - - # Initialize tensors - query = torch.randn( - batch_size, - q_len_per_request, - num_q_heads, - kv_lora_rank + qk_rope_head_dim, - device=device, - ).to(dtype) - - num_tokens = MAX_SEQ_LEN * batch_size - num_blocks = (num_tokens + page_size - 1) // page_size - - # Sequence lengths and block tables - seq_lens = [torch.randint(1, MAX_SEQ_LEN, (1,)).item() for _ in range(batch_size)] - seq_lens[-1] = MAX_SEQ_LEN - max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) - - blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size - max_num_blocks_per_seq = blocks_per_seq.max().item() - - # Generate random but unique block IDs for all sequences - total_blocks_needed = sum(blocks_per_seq) - all_block_ids = torch.randperm( - total_blocks_needed, device=device - ) # Random permutation - - # Generate unique block IDs for all sequences - block_id = 0 - block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device - ) - - # Populate block tables and track block assignments - block_id = 0 - for i in range(batch_size): - num_blocks_needed = blocks_per_seq[i] - block_tables[i, :num_blocks_needed] = all_block_ids[ - block_id : block_id + num_blocks_needed - ] - block_id += num_blocks_needed - - # Create interleaved KV cache - # Allocate more than needed blocks, block_id is just enough, to mimick real-world cases - kv_cache = torch.randn( - size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), device=device - ).to(dtype) - # (num_blocks, 1, page_size, kv_lora_rank + qk_rope_head_dim) - - # Allocate workspace buffer - # todo(Yingyi): calculate the actual size of workspace buffer - global global_workspace_buffer - if global_workspace_buffer is None: - global_workspace_buffer = torch.zeros( - 128 * 1024 * 1024, dtype=torch.int8, 
device="cuda:0" - ) - workspace_buffer = global_workspace_buffer - - bmm1_log2_scale_tensor = ( - torch.tensor( - [scale / ((128 + 64) ** 0.5 * math.log2(math.e))], - dtype=torch.float32, - device=device, - ) - if dynamic_scale - else None - ) - bmm2_scale_tensor = ( - torch.tensor([1.0], dtype=torch.float32, device=device) - if dynamic_scale - else None - ) - - # Run decode-MLA - output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( - query=query, - kv_cache=kv_cache.unsqueeze(1), - workspace_buffer=workspace_buffer, - qk_nope_head_dim=qk_nope_head_dim, - kv_lora_rank=kv_lora_rank, - qk_rope_head_dim=qk_rope_head_dim, - block_tables=block_tables, - seq_lens=seq_lens_tensor, - max_seq_len=max_seq_len, - bmm1_scale=scale / ((128 + 64) ** 0.5), - bmm2_scale=1.0, - bmm1_scale_log2_tensor=bmm1_log2_scale_tensor, - bmm2_scale_tensor=bmm2_scale_tensor, - ) - - # Run reference attention and align output - sm_scale = scale / ( - (128 + 64) ** 0.5 - ) # use head dimension before matrix absorption - workspace_buffer_ref = torch.empty( - 128 * 1024 * 1024, dtype=torch.int8, device=device - ) - wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( - workspace_buffer_ref, - backend="fa2", - ) - - if dtype == torch.float8_e4m3fn: - # convert query and kv_cache to bfloat16 - query = query.to(torch.bfloat16) - kv_cache = kv_cache.to(torch.bfloat16) - - q_indptr = ( - torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) - * q_len_per_request - ) - kv_indptr = torch.zeros_like(q_indptr) - kv_indptr[1:] = torch.cumsum(blocks_per_seq, dim=0) - kv_indices = all_block_ids.int() - - wrapper.plan( - q_indptr, - kv_indptr, - kv_indices, - seq_lens_tensor, - num_q_heads, - kv_lora_rank, - qk_rope_head_dim, - page_size, - True, - sm_scale, - query.dtype, - kv_cache.dtype, - ) - q_nope = query[..., :kv_lora_rank].view( - batch_size * q_len_per_request, num_q_heads, kv_lora_rank - ) - q_pe = query[..., kv_lora_rank:].view( - batch_size * q_len_per_request, num_q_heads, qk_rope_head_dim - ) - - # todo: fix kv_cache - ckv = kv_cache[..., :kv_lora_rank] - kpe = kv_cache[..., kv_lora_rank:] - - o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) - - # check is nan - assert not torch.isnan(o_ref).any(), "o_ref is nan" - assert not torch.isnan(output).any(), "output is nan" - - if dtype == torch.float8_e4m3fn: - try: - torch.testing.assert_close( - output, - o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), - rtol=1e-1, - atol=1e-1, - ) # todo: do reference with normal attention? 
- except AssertionError as e: - print("output:", output) - print("o_ref:", o_ref) - raise e - else: - try: - torch.testing.assert_close( - output, - o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), - rtol=1e-2, - atol=1e-2, - ) - except AssertionError as e: - print("output:", output) - print("o_ref:", o_ref) - raise e diff --git a/tests/test_trtllm_gen_mla.py b/tests/test_trtllm_gen_mla.py new file mode 100644 index 000000000..e1dea39ed --- /dev/null +++ b/tests/test_trtllm_gen_mla.py @@ -0,0 +1,212 @@ +import math + +import pytest +import torch + +import flashinfer + +global_workspace_buffer = None + + +@pytest.mark.parametrize( + "batch_size", + [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024], +) +@pytest.mark.parametrize("scale", [1.0, 0.5]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [32, 64]) +@pytest.mark.parametrize("q_len_per_request", [1, 2]) +@pytest.mark.parametrize("dynamic_scale", [False]) +def test_trtllm_batch_decode_mla( + batch_size: int, + scale: float, + dtype: torch.dtype, + page_size: int, + q_len_per_request: int, + dynamic_scale: bool, +): + if dynamic_scale and dtype != torch.float8_e4m3fn: + pytest.skip("Dynamic scale is not supported for non-fp8 dtype") + + torch.manual_seed(42) + device = "cuda:0" + + # Fixed max sequence length + MAX_SEQ_LEN = 1024 + + # Deepseek attention config (decode-MLA) + num_q_heads = 128 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + kv_lora_rank = 512 + + # Initialize tensors + query = torch.randn( + batch_size, + q_len_per_request, + num_q_heads, + kv_lora_rank + qk_rope_head_dim, + device=device, + ).to(dtype) + + num_tokens = MAX_SEQ_LEN * batch_size + num_blocks = (num_tokens + page_size - 1) // page_size + + # Sequence lengths and block tables + seq_lens = [torch.randint(1, MAX_SEQ_LEN, (1,)).item() for _ in range(batch_size)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) + + blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size + max_num_blocks_per_seq = blocks_per_seq.max().item() + + # Generate random but unique block IDs for all sequences + total_blocks_needed = sum(blocks_per_seq) + all_block_ids = torch.randperm( + total_blocks_needed, device=device + ) # Random permutation + + # Generate unique block IDs for all sequences + block_id = 0 + block_tables = torch.zeros( + (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device + ) + + # Populate block tables and track block assignments + block_id = 0 + for i in range(batch_size): + num_blocks_needed = blocks_per_seq[i] + block_tables[i, :num_blocks_needed] = all_block_ids[ + block_id : block_id + num_blocks_needed + ] + block_id += num_blocks_needed + + # Create interleaved KV cache + # Allocate more than needed blocks, block_id is just enough, to mimick real-world cases + kv_cache = torch.randn( + size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), device=device + ).to(dtype) + # (num_blocks, 1, page_size, kv_lora_rank + qk_rope_head_dim) + + # Allocate workspace buffer + # todo(Yingyi): calculate the actual size of workspace buffer + global global_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.zeros( + 128 * 1024 * 1024, dtype=torch.int8, device=device + ) + workspace_buffer = global_workspace_buffer + + bmm1_log2_scale_tensor = ( + torch.tensor( + [scale / ((128 + 64) ** 0.5 * math.log2(math.e))], + dtype=torch.float32, + 
device=device, + ) + if dynamic_scale + else None + ) + bmm2_scale_tensor = ( + torch.tensor([1.0], dtype=torch.float32, device=device) + if dynamic_scale + else None + ) + + # Run decode-MLA + output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + query=query, + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=qk_nope_head_dim, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=block_tables, + seq_lens=seq_lens_tensor, + max_seq_len=max_seq_len, + bmm1_scale=scale / ((128 + 64) ** 0.5), + bmm2_scale=1.0, + bmm1_scale_log2_tensor=bmm1_log2_scale_tensor, + bmm2_scale_tensor=bmm2_scale_tensor, + ) + + # Run reference attention and align output + sm_scale = scale / ( + (128 + 64) ** 0.5 + ) # use head dimension before matrix absorption + workspace_buffer_ref = torch.empty( + 128 * 1024 * 1024, dtype=torch.int8, device=device + ) + wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + workspace_buffer_ref, + backend="fa2", + ) + + if dtype == torch.float8_e4m3fn: + # convert query and kv_cache to bfloat16 + query = query.to(torch.bfloat16) + kv_cache = kv_cache.to(torch.bfloat16) + + q_indptr = ( + torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) + * q_len_per_request + ) + kv_indptr = torch.zeros_like(q_indptr) + kv_indptr[1:] = torch.cumsum(blocks_per_seq, dim=0) + kv_indices = all_block_ids.int() + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + seq_lens_tensor, + num_q_heads, + kv_lora_rank, + qk_rope_head_dim, + page_size, + True, + sm_scale, + query.dtype, + kv_cache.dtype, + ) + q_nope = query[..., :kv_lora_rank].view( + batch_size * q_len_per_request, num_q_heads, kv_lora_rank + ) + q_pe = query[..., kv_lora_rank:].view( + batch_size * q_len_per_request, num_q_heads, qk_rope_head_dim + ) + + # todo: fix kv_cache + ckv = kv_cache[..., :kv_lora_rank] + kpe = kv_cache[..., kv_lora_rank:] + + o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) + + # check is nan + assert not torch.isnan(o_ref).any(), "o_ref is nan" + assert not torch.isnan(output).any(), "output is nan" + + if dtype == torch.float8_e4m3fn: + try: + torch.testing.assert_close( + output, + o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), + rtol=1e-1, + atol=1e-1, + ) # todo: do reference with normal attention? + except AssertionError as e: + print("output:", output) + print("o_ref:", o_ref) + raise e + else: + try: + torch.testing.assert_close( + output, + o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), + rtol=1e-2, + atol=1e-2, + ) + except AssertionError as e: + print("output:", output) + print("o_ref:", o_ref) + raise e