
Commit dbb438c

hotfix: update mxfp4 groupwise-scaled gemm unittests (#1359)
## 📌 Description

The `fp4_swizzle_blockscale_sm100` function was removed in #1214, which broke the `tests/test_groupwise_scaled_gemm_mxfp4.py` unit test; this PR fixes the issue.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @ttyio @azhurkevich
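In concrete terms, the fix swaps the removed binding for the newer `nvfp4_block_scale_interleave_sm100` one. A minimal before/after sketch of the call site, based on the diff below (the variable names come from the test helper shown there):

```python
# Before (binding removed in #1214): the caller preallocates the output and
# passes the flattened tensors plus explicit shape arguments.
out = torch.empty_like(padded_input_sf)
get_fp4_quantization_sm100_module().fp4_swizzle_blockscale_sm100(
    padded_input_sf.flatten(0, 1),
    out.flatten(0, 1),
    out.shape[0],
    out.shape[1],
    out.shape[2],
)

# After (this PR): the replacement binding takes the padded scale factors and
# returns the interleaved result, which is then reshaped back.
out = get_fp4_quantization_sm100_module().nvfp4_block_scale_interleave_sm100(
    padded_input_sf
)
out = out.view(padded_input_sf.shape)
```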
1 parent 7e1c830 commit dbb438c

File tree

1 file changed (+2 −7 lines)


tests/test_groupwise_scaled_gemm_mxfp4.py

Lines changed: 2 additions & 7 deletions
@@ -64,13 +64,8 @@ def swizzle_blockscale(
         _pad_scale_factors(unswizzled_sf[i], m, n, sf_vec_size) for i in range(b)
     ]
     padded_input_sf = torch.stack(padded_input_sf_chunked)
-    out = torch.empty_like(padded_input_sf)
-    get_fp4_quantization_sm100_module().fp4_swizzle_blockscale_sm100(
-        padded_input_sf.flatten(0, 1),
-        out.flatten(0, 1),
-        out.shape[0],
-        out.shape[1],
-        out.shape[2],
+    out = get_fp4_quantization_sm100_module().nvfp4_block_scale_interleave_sm100(
+        padded_input_sf
     )
     out = out.view(padded_input_sf.shape)
     return out
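For readability, here is the updated helper with the hunk applied. This is a sketch reconstructed from the diff above: the parameter list, the import path, and any line outside the hunk are assumptions, not the verbatim file contents.

```python
import torch

# Assumed import path for the SM100 FP4 quantization module; the test file's
# actual import may differ.
from flashinfer.fp4_quantization import get_fp4_quantization_sm100_module


def swizzle_blockscale(unswizzled_sf, b, m, n, sf_vec_size):
    # Parameter list assumed from the names used inside the hunk.
    # _pad_scale_factors is the test-local helper defined earlier in the same
    # file; per the hunk, it pads each batch's scale factors before stacking.
    padded_input_sf_chunked = [
        _pad_scale_factors(unswizzled_sf[i], m, n, sf_vec_size) for i in range(b)
    ]
    padded_input_sf = torch.stack(padded_input_sf_chunked)
    # Replacement API: the binding allocates and returns the interleaved block
    # scales instead of writing into a caller-provided buffer.
    out = get_fp4_quantization_sm100_module().nvfp4_block_scale_interleave_sm100(
        padded_input_sf
    )
    out = out.view(padded_input_sf.shape)
    return out
```

The net change is that the test no longer preallocates the output or passes explicit shape arguments; the binding returns the interleaved scales directly, and the helper reshapes them back to the padded layout.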
