
Commit c691768

raise error for group_gemm_fp8_nt_groupwise when num_groups > 1 on sm120/121 (#1862)
## 📌 Description

- Raise a `RuntimeError` for `group_gemm_fp8_nt_groupwise` when `num_groups > 1` on SM120/121.
- Skip the related tests.
- Tentatively rename `tests/GEMM` to `tests/gemm` for consistency with the other components, which use lowercase directory names across the codebase (please correct me if this is wrong!).

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

@yzh119 @aleozlx @nvmbreughe @bkryu
1 parent 1595175 commit c691768

13 files changed: +13 −2 lines changed

flashinfer/gemm.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -2304,6 +2304,11 @@ def group_gemm_fp8_nt_groupwise(
     assert out.dtype == out_dtype

     if is_sm120a_supported(a.device) or is_sm121a_supported(a.device):
+        # it has correctness issues for num_groups > 1
+        if num_groups > 1:
+            raise RuntimeError(
+                "group_gemm_fp8_nt_groupwise has correctness issues for num_groups > 1 on SM120/121"
+            )
         # SM120/121 doesn't use mma_sm parameter
         get_gemm_sm120_module().group_gemm_fp8_nt_groupwise(
             int_workspace_buffer,
```
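The practical effect of this guard is that callers on SM120/121 must keep `num_groups == 1` until the kernel is fixed. Below is a minimal caller-side sketch of the same pre-check; `is_sm12x` and `check_group_count` are illustrative names rather than FlashInfer API, and the capability test mirrors the one used in the test skip further down.

```python
import torch


def is_sm12x(device: torch.device | None = None) -> bool:
    """Whether the CUDA device reports major compute capability 12 (SM120/SM121)."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability(device)
    return major == 12


def check_group_count(num_groups: int, device: torch.device | None = None) -> None:
    """Mirror the guard added in this commit before calling the grouped FP8 GEMM."""
    if num_groups > 1 and is_sm12x(device):
        raise RuntimeError(
            "group_gemm_fp8_nt_groupwise has correctness issues for "
            "num_groups > 1 on SM120/121"
        )
```

One possible workaround on affected hardware is to split a multi-group call into per-group calls, at the cost of launching one kernel per group.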

scripts/task_jit_run_tests_part1.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -7,7 +7,7 @@ set -x

 pip install -e . -v

-# pytest -s tests/GEMM/test_group_gemm.py
+# pytest -s tests/gemm/test_group_gemm.py
 pytest -s tests/attention/test_logits_cap.py
 pytest -s tests/attention/test_sliding_window.py
 pytest -s tests/attention/test_tensor_cores_decode.py
```

scripts/task_jit_run_tests_part4.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,7 +9,7 @@ pip install -e . -v

 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True # avoid memory fragmentation
 pytest -s tests/attention/test_deepseek_mla.py
-pytest -s tests/GEMM/test_group_gemm.py
+pytest -s tests/gemm/test_group_gemm.py
 pytest -s tests/attention/test_batch_prefill_kernels.py
 # NOTE(Zihao): need to fix tile size on KV dimension for head_dim=256 on small shared memory architecture (sm89)
 # pytest -s tests/attention/test_batch_attention.py
```
5 files renamed without changes.

tests/GEMM/test_groupwise_scaled_gemm_fp8.py renamed to tests/gemm/test_groupwise_scaled_gemm_fp8.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -146,6 +146,12 @@ def test_fp8_groupwise_group_gemm(
     scale_major_mode,
     out_dtype,
 ):
+    if group_size > 1 and torch.cuda.get_device_capability()[0] in [
+        12,
+    ]:
+        pytest.skip(
+            "group_gemm_fp8_nt_groupwise has correctness issues for num_groups > 1 on SM120/121"
+        )
     torch.random.manual_seed(0)
     tile_size = 128

```
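The inline skip above works, but if more FP8 group-GEMM tests need the same gate, the capability check could be hoisted into a shared helper. A small sketch, assuming `pytest` and `torch` are available; `skip_if_multi_group_on_sm12x` is an illustrative name, not part of the FlashInfer test suite:

```python
import pytest
import torch


def _is_sm12x() -> bool:
    # SM120/SM121 report a major compute capability of 12.
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 12


def skip_if_multi_group_on_sm12x(num_groups: int) -> None:
    """Call at the top of a test whose parametrization uses num_groups > 1."""
    if num_groups > 1 and _is_sm12x():
        pytest.skip(
            "group_gemm_fp8_nt_groupwise has correctness issues for "
            "num_groups > 1 on SM120/121"
        )
```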

File renamed without changes.
