Skip to content

Commit 43ec6c7

Browse files
authored
test: use .float() in in F.cosine_similarity() in bmm_fp8 test (flashinfer-ai#2266)
<!-- .github/pull_request_template.md --> ## 📌 Description saw some [test failures](https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/jobs/247866505) on Blackwell boards after flashinfer-ai#2261, all the failed assertions are related to the large value 10304. Use `.float()` to help reduce precision loss during `cosine_similarity` (`dot(x, y) / (||x|| * ||y||)`) check. ``` FAILED tests/gemm/test_bmm_fp8.py::test_bmm_fp8[True-cutlass-res_dtype1-mat2_dtype0-input_dtype0-256-10304-128-16] - AssertionError: assert tensor(0., device='cuda:0') > 0.99 2025-12-24T07:00:08.299846Z 01O FAILED tests/gemm/test_bmm_fp8.py::test_bmm_fp8[False-cudnn-res_dtype1-mat2_dtype0-input_dtype1-256-10304-128-16] - AssertionError: assert tensor(0., device='cuda:0') > 0.99 ... # the failure occurs for all backend (cutlass, cudnn, etc) ``` cc: @zihaoye @bkryu ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Improved test accuracy by ensuring tensor comparisons use floating-point precision for cosine similarity calculations. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent fe093d6 commit 43ec6c7

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/gemm/test_bmm_fp8.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_t
5151
backend=backend,
5252
)
5353

54-
cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
54+
cos_sim = F.cosine_similarity(
55+
reference.reshape(-1).float(), res.reshape(-1).float(), dim=0
56+
)
5557
assert cos_sim > 0.99
5658

5759

0 commit comments

Comments
 (0)