Commit 927a41e

feat: add mm_fp4 use cudnn backend (#1288)
Supports a/b input type e2m1 and block quant type e4m3 with block size 16; output is bfloat16 or fp16.

## 📌 Description

Initial addition of `mm_fp4` using the cuDNN backend.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Signed-off-by: Vincent Huang <[email protected]>
1 parent 8587c21 commit 927a41e
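
The "block quant type e4m3 with block size 16" scheme from the description can be sketched in plain NumPy. This is an illustrative sketch, not FlashInfer's kernel: `block_quantize`, `BLOCK`, and `E2M1_MAX` are names chosen here for illustration (FP4 e2m1's largest finite magnitude is 6.0). Each 16-element block along the last axis gets one scale factor, so an (M, K) input yields an (M, K/16) scale tensor, matching the `sf.reshape((-1, input.shape[-1] // sf_vec_size))` shape in `fp4_quantize`.

```python
import numpy as np

BLOCK = 16        # scale-factor vector size (sf_vec_size in fp4_quantize)
E2M1_MAX = 6.0    # largest finite magnitude representable in FP4 e2m1

def block_quantize(x: np.ndarray):
    """Compute one scale per 16-element block of the last axis (sketch only)."""
    m, k = x.shape
    assert k % BLOCK == 0
    blocks = x.reshape(m, k // BLOCK, BLOCK)
    # One scale per block so each block's max magnitude maps into the e2m1 range.
    scales = np.maximum(np.abs(blocks).max(axis=-1) / E2M1_MAX, 1e-12)
    q = blocks / scales[..., None]   # scaled values now lie in [-6, 6]
    return q.reshape(m, k), scales   # scales has shape (m, k // BLOCK)

x = np.random.randn(4, 64).astype(np.float32)
q, sf = block_quantize(x)
print(sf.shape)  # (4, 4)
```

Multiplying each quantized block by its scale recovers the original values (up to the precision lost when the scaled values are later rounded to e2m1, which this sketch omits).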

File tree: 4 files changed, +449 −43 lines


flashinfer/__init__.py (1 addition, 0 deletions)

```diff
@@ -65,6 +65,7 @@
 )
 from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
 from .gemm import bmm_fp8 as bmm_fp8
+from .gemm import mm_fp4 as mm_fp4
 from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper
 from .norm import fused_add_rmsnorm as fused_add_rmsnorm
 from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
```

flashinfer/fp4_quantization.py (9 additions, 0 deletions)

```diff
@@ -253,6 +253,11 @@ def fp4_quantize(
     if sf_vec_size != 16 and sf_vec_size != 32:
         raise NotImplementedError("sf_vec_size can only be 16 or 32")
 
+    # for column major input, we need to transpose the input
+    is_column_major = input.stride(-2) == 1
+    if is_column_major:
+        input = input.transpose(-2, -1)
+
     assert input.shape[-1] % sf_vec_size == 0
     x_q, sf = get_fp4_quantization_sm100_module().fp4_quantize_sm100(
         input,
@@ -262,6 +267,10 @@ def fp4_quantize(
         is_sf_swizzled_layout,
     )
     sf = sf.reshape((-1, input.shape[-1] // sf_vec_size))
+    if is_column_major:
+        x_q = x_q.transpose(-2, -1)
+        sf = sf.transpose(-2, -1)
+
     return x_q, sf
```
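
The transpose round-trip added in this diff can be sketched with a NumPy analogue. Here `quantize_rowwise` is a hypothetical stand-in for the SM100 quantization kernel, which assumes a row-major input; the pattern detects a column-major input via its stride (NumPy strides are in bytes, so `strides[-2] == itemsize` corresponds to PyTorch's `stride(-2) == 1`), transposes to row-major, and transposes the result back so the caller sees the original orientation.

```python
import numpy as np

def quantize_rowwise(x: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for fp4_quantize_sm100: a row-wise op
    # that assumes the last axis is contiguous.
    return x * 0.5

def quantize(x: np.ndarray) -> np.ndarray:
    # Column-major (Fortran-order) input: unit stride on the second-to-last axis.
    is_column_major = x.strides[-2] == x.itemsize
    if is_column_major:
        x = x.T  # view only: the transpose of a column-major array is row-major
    out = quantize_rowwise(np.ascontiguousarray(x))
    if is_column_major:
        out = out.T  # restore the caller's original orientation
    return out

a = np.asfortranarray(np.arange(12, dtype=np.float32).reshape(3, 4))
b = quantize(a)
assert b.shape == (3, 4)  # shape unchanged from the caller's point of view
```

Transposing twice is cheap here because both transposes are metadata-only views; only `ascontiguousarray` may copy, and only when the kernel actually needs contiguous rows.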

0 commit comments
