Commit 4fe837f

Fix sink attention accuracy regression, add sink test and cleanup. (#1758)
## 📌 Description

Update the trtllm-gen cubin to fix the accuracy regression involving attention sinks.

Integrate the sink test into the main attention test, since the old standalone sink test did not exercise or catch the fp8 accuracy issue.

Clean up the attention test code. The number of attention test cases had inflated considerably as parameters were added; reduce it by using pairwise combinations instead of the full Cartesian product (a sketch of the idea follows below).

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent c721fb7 commit 4fe837f
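The test-matrix reduction mentioned in the description swaps a full Cartesian product of parameters for pairwise combinations. A minimal sketch of the idea, with made-up parameter names and a deliberately simplified pairing rule (the actual selection logic in the test suite may differ):

```python
import itertools

import pytest

BATCH_SIZES = [1, 4, 16]
HEAD_DIMS = [64, 128]
CAUSAL = [False, True]

# Full Cartesian product: 3 * 2 * 2 = 12 cases.
CARTESIAN = list(itertools.product(BATCH_SIZES, HEAD_DIMS, CAUSAL))

# Simplified pairwise-style reduction: cycle the shorter axes so every value
# still appears, but only len(BATCH_SIZES) = 3 cases are generated. A true
# all-pairs generator would additionally guarantee every value *pair* co-occurs.
PAIRWISE = [
    (bs, HEAD_DIMS[i % len(HEAD_DIMS)], CAUSAL[i % len(CAUSAL)])
    for i, bs in enumerate(BATCH_SIZES)
]


@pytest.mark.parametrize("batch_size,head_dim,causal", PAIRWISE)
def test_attention_case(batch_size, head_dim, causal):
    # Placeholder body; the real tests build attention inputs and compare
    # against a reference implementation.
    assert batch_size in BATCH_SIZES and head_dim in HEAD_DIMS
```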

File tree

3 files changed (+206, -118 lines)

flashinfer/artifacts.py

Lines changed: 2 additions & 2 deletions
@@ -69,7 +69,7 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):


 class ArtifactPath:
-    TRTLLM_GEN_FMHA: str = "538f8e38ace07f701f61e26b138b2b8c70ce9e8e/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802/"
     )
@@ -82,7 +82,7 @@ class ArtifactPath:

 class MetaInfoHash:
     TRTLLM_GEN_FMHA: str = (
-        "71f06a8fc03d28cc94ee6fc180fb7e37256a9e1c30ab2a6c0bf20a2d97af3eff"
+        "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
     )
     TRTLLM_GEN_BMM: str = (
         "c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34"
tests/test_attention_sink_blackwell.py

Lines changed: 1 addition & 8 deletions
@@ -18,7 +18,6 @@
 import pytest
 import torch
 from sink_attention_reference import sink_attention_unified
-from conftest import assert_close_with_mismatch_tolerance

 import flashinfer
 from flashinfer.utils import get_compute_capability
@@ -122,13 +121,7 @@ def test_blackwell_trtllm_gen_decode_attention_sink(
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")

-    assert_close_with_mismatch_tolerance(
-        o_ref,
-        output,
-        atol=atol,
-        rtol=rtol,
-        max_mismatched_elements=int(output.numel() * 0.01),
-    )
+    torch.testing.assert_close(o_ref, output, atol=atol, rtol=rtol)


 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
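The test change above drops the 1%-mismatch-tolerant comparison in favor of a strict elementwise `torch.testing.assert_close`. A small standalone illustration of the difference; the tensors and tolerances are made up for demonstration:

```python
import torch

o_ref = torch.zeros(1000)
output = o_ref.clone()
output[0] = 1.0  # a single badly-off element

# The old helper tolerated up to 1% mismatched elements, so a localized error
# like this could pass: int(output.numel() * 0.01) == 10 allowed mismatches.
mismatch_budget = int(output.numel() * 0.01)

# torch.testing.assert_close has no mismatch budget: one bad element fails.
try:
    torch.testing.assert_close(o_ref, output, atol=1e-2, rtol=1e-2)
except AssertionError:
    print(f"strict check flagged the outlier (old budget was {mismatch_budget})")
```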
