Skip to content

Commit a003c02

Browse files
authored
Add test case for Qwen3N (#2532)
<!-- .github/pull_request_template.md --> ## 📌 Description Add test case for Qwen3N, and Qwen3.5 according to vllm-project/vllm#34131 <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Expanded test coverage by adding additional head-configuration cases across multiple test scenarios to improve reliability and catch more edge cases. * No changes to test logic or public interfaces; only parameterized inputs were extended. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent f4d10a7 commit a003c02

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

tests/gdn/test_prefill_delta_rule.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,16 @@ def _test_prefill_kernel(
144144
@pytest.mark.parametrize("head_size", [128])
145145
@pytest.mark.parametrize(
146146
"num_q_heads, num_k_heads, num_v_heads",
147-
[(1, 1, 1), (4, 1, 1), (3, 3, 3), (6, 2, 2), (1, 1, 2), (2, 2, 4)],
147+
[
148+
(1, 1, 1),
149+
(4, 1, 1),
150+
(3, 3, 3),
151+
(6, 2, 2),
152+
(1, 1, 2),
153+
(2, 2, 4),
154+
(16, 16, 32),
155+
(16, 16, 64),
156+
],
148157
)
149158
@pytest.mark.parametrize("seq_lens", [[64], [128], [256], [256, 256], [64, 128, 512]])
150159
@pytest.mark.parametrize("block_size", [64])
@@ -186,7 +195,16 @@ def test_prefill_kernel_basic(
186195
@pytest.mark.parametrize("head_size", [128])
187196
@pytest.mark.parametrize(
188197
"num_q_heads, num_k_heads, num_v_heads",
189-
[(1, 1, 1), (4, 1, 1), (3, 3, 3), (6, 2, 2), (1, 1, 2), (2, 2, 4)],
198+
[
199+
(1, 1, 1),
200+
(4, 1, 1),
201+
(3, 3, 3),
202+
(6, 2, 2),
203+
(1, 1, 2),
204+
(2, 2, 4),
205+
(16, 16, 32),
206+
(16, 16, 64),
207+
],
190208
)
191209
@pytest.mark.parametrize(
192210
"seq_lens",
@@ -390,7 +408,8 @@ def concat_varlen(t1, cu_seq_lens1, t2, cu_seq_lens2):
390408
@pytest.mark.parametrize("scale", [1.0, "auto"])
391409
@pytest.mark.parametrize("head_size", [128])
392410
@pytest.mark.parametrize(
393-
"num_q_heads, num_k_heads, num_v_heads", [(6, 2, 2), (2, 2, 4)]
411+
"num_q_heads, num_k_heads, num_v_heads",
412+
[(6, 2, 2), (2, 2, 4), (16, 16, 32), (16, 16, 64)],
394413
)
395414
@pytest.mark.parametrize(
396415
"seq_lens1, seq_lens2",

0 commit comments

Comments
 (0)