tests: Add batch size 1 cases to test_trtllm_gen_attention.py that fail, marked xfail (#1897)
<!-- .github/pull_request_template.md -->
## 📌 Description
<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->
trtllm-gen's attention kernels have been found to fail tests when the
batch size is 1.
This PR adds batch-size-1 cases to:
`test_trtllm_gen_prefill_deepseek`: triggers an illegal memory access
(IMA) with the newly added parameters
```
## Running pytest ./tests/attention/test_trtllm_gen_attention.py::test_trtllm_gen_prefill_deepseek -v
> default_generator.manual_seed(seed)
E torch.AcceleratorError: CUDA error: an illegal memory access was encountered
E CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E For debugging consider passing CUDA_LAUNCH_BLOCKING=1
E Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
/opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/random.py:129: AcceleratorError
```
`test_trtllm_batch_decode`: produces incorrect outputs with the newly
added parameters
```
## Running pytest ./tests/attention/test_trtllm_gen_attention.py::test_trtllm_batch_decode -v
> torch.testing.assert_close(
output.float(),
output_wrapper.float(),
rtol=1e-1,
atol=1e-1,
)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 1480 / 8192 (18.1%)
E Greatest absolute difference: 64.021484375 at index (0, 46, 106) (up to 0.1 allowed)
E Greatest relative difference: 1.625 at index (0, 56, 109) (up to 0.1 allowed)
```
**These test cases have been marked as `pytest.xfail()`.** To avoid
combinatorial growth of the test parameter grid, the batch-size-1
cases were defined as separate test functions.
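The separation described above might look like the following minimal sketch; the function name, parametrization, and xfail reason are illustrative only, not the PR's actual code:

```python
import pytest


# Hypothetical sketch: a batch-size-1 case split out as its own test
# function so the parametrize grid of the existing test does not grow.
@pytest.mark.parametrize("batch_size", [1])
def test_trtllm_batch_decode_bs1(batch_size):
    # Known failure for batch_size == 1; pytest.xfail() records the case
    # as expected-to-fail instead of reporting an error.
    pytest.xfail("trtllm-gen decode produces incorrect outputs for batch_size=1")
```

Under `pytest`, such a case is reported as `xfailed` rather than `failed`, which matches the run summaries below.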
B200 status before this PR: `2052 passed, 264 skipped in 177.80s (0:02:57)`
B200 status after this PR: `2052 passed, 264 skipped, 3 xfailed in 195.14s (0:03:15)`
Status is tracked in [Issue 1898](#1898).
## 🔍 Related Issues
<!-- Link any related issues here -->
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->