Commit 0522a32

fix: fix the failed sampling unittest on 5090 (#1886)
## 📌 Description

Applying softmax followed by top_k_renorm does not guarantee bitwise-identical results compared to top_k_mask followed by softmax. This may cause slight differences in subsequent top-p sampling. In this PR we relax the check to allow up to a 1% mismatch rate.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
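The description's claim can be illustrated with a minimal sketch in plain Python. This is not FlashInfer's kernel code: the function names below are illustrative stand-ins for the two orderings, and the logits are made up. Both paths compute the same distribution on paper, but they perform different floating-point operations (one extra division in the renormalization path, a different normalizer in the masking path), so the results are only guaranteed to agree up to rounding.

```python
import math


def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_indices(values, k):
    # Indices of the k largest entries (ties broken by position).
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    return set(order[:k])


def softmax_then_top_k_renorm(logits, k):
    # Order 1: normalize over the full vocabulary, keep the top k, renormalize.
    probs = softmax(logits)
    keep = top_k_indices(probs, k)
    kept_mass = sum(probs[i] for i in keep)
    return [probs[i] / kept_mass if i in keep else 0.0 for i in range(len(probs))]


def top_k_mask_then_softmax(logits, k):
    # Order 2: mask everything outside the top k to -inf, then normalize once.
    keep = top_k_indices(logits, k)
    masked = [x if i in keep else -math.inf for i, x in enumerate(logits)]
    return softmax(masked)


if __name__ == "__main__":
    logits = [0.37 * i - 2.0 for i in range(16)]  # made-up example logits
    a = softmax_then_top_k_renorm(logits, 4)
    b = top_k_mask_then_softmax(logits, 4)
    print("max abs difference:", max(abs(x - y) for x, y in zip(a, b)))
```

Any difference here is on the order of rounding error, but a later top-p cutoff compares cumulative probabilities against a threshold, so a last-bit difference can occasionally move a token across that boundary and change the sampled id.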
1 parent 68826ac commit 0522a32

1 file changed: +12 -1 lines changed

tests/utils/test_sampling.py

Lines changed: 12 additions & 1 deletion
@@ -377,7 +377,18 @@ def test_top_k_top_p_sampling_from_probs_logits_alignment(batch_size, vocab_size
         filter_apply_order="top_k_first",
         generator=generator_probs,
     )
-    assert torch.all(samples == samples_ref)
+
+    num_matches = (samples == samples_ref).sum().item()
+    match_rate = num_matches / samples.numel()
+
+    # NOTE(Zihao): Applying softmax followed by top_k_renorm (softmax -> top_k_renorm)
+    # does not guarantee bitwise-identical results compared to top_k_mask followed by softmax (top_k_mask -> softmax).
+    # This may cause slight differences in subsequent top-p sampling.
+    # We tolerate up to a 1% mismatch rate.
+    assert match_rate >= 0.99, (
+        f"Sample match rate {match_rate:.2%} is below threshold "
+        f"({batch_size - num_matches}/{batch_size} mismatches, expected <=1%)"
+    )
 
 
 @pytest.mark.parametrize("batch_size", [1, 99, 989])
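
The tolerance check in the diff could be factored into a small helper. The sketch below is hypothetical, not part of the commit: `assert_match_rate` and `count_matches` are made-up names, and it works on plain Python sequences instead of tensors. It takes both the rate and the failure message from the number of compared elements. That agrees with the diff's `batch_size`-based message only under the assumption that `samples` holds one element per batch row.

```python
def count_matches(samples, samples_ref):
    # Number of positions where the two sample sequences agree.
    if len(samples) != len(samples_ref):
        raise ValueError("samples and samples_ref must have the same length")
    return sum(1 for a, b in zip(samples, samples_ref) if a == b)


def assert_match_rate(samples, samples_ref, min_rate=0.99):
    # Hypothetical helper: fail only if the match rate drops below min_rate.
    total = len(samples)
    if total == 0:
        raise ValueError("no samples to compare")
    num_matches = count_matches(samples, samples_ref)
    rate = num_matches / total
    assert rate >= min_rate, (
        f"Sample match rate {rate:.2%} is below threshold "
        f"({total - num_matches}/{total} mismatches, "
        f"expected <={1 - min_rate:.0%})"
    )
    return rate
```

With a 1% tolerance, a run with fewer than 100 compared samples still fails on a single mismatch, since 1/99 is already above 1%.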
