Skip to content

Commit 40d3fea

Browse files
authored
tests: Add batch size 1 cases to test_trtllm_gen_attention.py that fail, marked xfail (#1897)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> Trtllm-gen's attention kernels have been discovered to fail tests when batch size is 1. Current PR adds batch size 1 cases to: `test_trtllm_gen_prefill_deepseek`: that triggers an IMA with the newly added parameters ``` ## Running pytest ./tests/attention/test_trtllm_gen_attention.py::test_trtllm_gen_prefill_deepseek -v > default_generator.manual_seed(seed) E torch.AcceleratorError: CUDA error: an illegal memory access was encountered E CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. E For debugging consider passing CUDA_LAUNCH_BLOCKING=1 E Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/random.py:129: AcceleratorError ``` `test_trtllm_batch_decode`: that produces incorrect outputs with newly added parameters ``` ## Running pytest ./tests/attention/test_trtllm_gen_attention.py::test_trtllm_batch_decode -v > torch.testing.assert_close( output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1, ) E AssertionError: Tensor-likes are not close! E E Mismatched elements: 1480 / 8192 (18.1%) E Greatest absolute difference: 64.021484375 at index (0, 46, 106) (up to 0.1 allowed) E Greatest relative difference: 1.625 at index (0, 56, 109) (up to 0.1 allowed) ``` **These test cases have been marked as `pytest.xfail()`.** To avoid a combinatorial growth of test parameter combinations, these batch size 1 cases were defined as separate test functions. B200 status before PR: `2052 passed, 264 skipped in 177.80s (0:02:57)` B200 status after PR: `2052 passed, 264 skipped, 3 xfailed in 195.14s (0:03:15)` Status tracked in [Issue 1898](#1898) ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! 
Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 674843f commit 40d3fea

File tree

1 file changed

+68
-2
lines changed

1 file changed

+68
-2
lines changed

tests/attention/test_trtllm_gen_attention.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ def test_trtllm_batch_prefill(
564564
)
565565
@pytest.mark.parametrize("enable_pdl", [True, False, None])
566566
@pytest.mark.parametrize("enable_sink", [True, False])
567+
@pytest.mark.parametrize("max_in_kv_len", [110])
567568
def test_trtllm_batch_decode(
568569
kv_layout,
569570
batch_size,
@@ -577,6 +578,7 @@ def test_trtllm_batch_decode(
577578
kv_dtype,
578579
enable_pdl,
579580
enable_sink,
581+
max_in_kv_len,
580582
):
581583
compute_capability = get_compute_capability(torch.device(device="cuda"))
582584
if compute_capability[0] != 10:
@@ -589,12 +591,11 @@ def test_trtllm_batch_decode(
589591
# Set up test parameters
590592
torch.manual_seed(0)
591593
head_dim = 128
592-
MAX_IN_KV_LEN = 110
593594

594595
# Generate random sequence lengths
595596
num_qo_heads = num_kv_heads * head_grp_size
596597
q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
597-
batch_size, q_len_per_req, MAX_IN_KV_LEN
598+
batch_size, q_len_per_req, max_in_kv_len
598599
)
599600

600601
# Create query tensor and related data
@@ -805,6 +806,56 @@ def test_trtllm_batch_decode(
805806
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
806807

807808

# trtllm-gen only supports the HND kv layout.
@pytest.mark.parametrize("kv_layout", ["HND"])
@pytest.mark.parametrize(
    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
    [(1, 1, 16, 8, 8)],
)
@pytest.mark.parametrize("window_left", [-1])
@pytest.mark.parametrize("q_dtype,kv_dtype,o_dtype", [("fp8", "fp8", "fp8")])
@pytest.mark.parametrize("enable_pdl", [None])
@pytest.mark.parametrize("enable_sink", [False])
@pytest.mark.parametrize("max_in_kv_len", [8192])
def test_trtllm_batch_decode_bs1(
    kv_layout,
    batch_size,
    q_len_per_req,
    page_size,
    num_kv_heads,
    head_grp_size,
    window_left,
    q_dtype,
    o_dtype,
    kv_dtype,
    enable_pdl,
    enable_sink,
    max_in_kv_len,
):
    """Known-failing batch-size-1 decode case, tracked in issue #1898.

    Delegates to ``test_trtllm_batch_decode`` so the case exercises the
    real kernel path once the underlying bug is fixed; until then the
    imperative xfail below short-circuits the run.
    """
    # pytest.xfail() raises immediately, so the delegate call below is
    # intentionally unreachable while the kernel bug persists.
    pytest.xfail("trtllm-gen decode gets incorrect output with bs1")
    test_trtllm_batch_decode(
        kv_layout=kv_layout,
        batch_size=batch_size,
        q_len_per_req=q_len_per_req,
        page_size=page_size,
        num_kv_heads=num_kv_heads,
        head_grp_size=head_grp_size,
        window_left=window_left,
        q_dtype=q_dtype,
        o_dtype=o_dtype,
        kv_dtype=kv_dtype,
        enable_pdl=enable_pdl,
        enable_sink=enable_sink,
        max_in_kv_len=max_in_kv_len,
    )
858+
808859
@pytest.mark.parametrize("batch_size", [4, 128, 256])
809860
@pytest.mark.parametrize("s_qo", [32, 64, 87])
810861
@pytest.mark.parametrize("s_kv", [32, 64, 87])
@@ -938,6 +989,21 @@ def test_trtllm_gen_prefill_deepseek(
938989
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
939990

940991

@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("s_qo", [1024])
@pytest.mark.parametrize("s_kv", [1024])
@pytest.mark.parametrize("num_kv_heads", [128])
@pytest.mark.parametrize("head_grp_size", [1])
@pytest.mark.parametrize("causal", [True, False])
def test_trtllm_gen_prefill_deepseek_bs1(
    batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
):
    """Known-failing batch-size-1 deepseek prefill case, tracked in issue #1898.

    Delegates to ``test_trtllm_gen_prefill_deepseek``; the imperative xfail
    below raises first, so the delegate is not reached while the bug persists.
    """
    pytest.xfail("trtllm-gen prefill triggers an IMA with bs1")
    test_trtllm_gen_prefill_deepseek(
        batch_size=batch_size,
        s_qo=s_qo,
        s_kv=s_kv,
        num_kv_heads=num_kv_heads,
        head_grp_size=head_grp_size,
        causal=causal,
    )
1006+
9411007
if __name__ == "__main__":
9421008
test_trtllm_batch_prefill("HND", 128, 32, 2, 5, -1, "fp16", "fp16", "fp16", False)
9431009
test_trtllm_batch_decode("HND", 256, 3, 64, 4, 5, -1, "fp8", "fp8", "fp8", True)

0 commit comments

Comments
 (0)