Commit 22a62ea

bugfix: fixed cutlass fused moe usage of FP4QuantizationSFLayout::SWIZZLED (#1371)
## 📌 Description

cutlass fused moe modules are broken after #1355 because the structure of `FP4QuantizationSFLayout` has changed. This PR fixes the issue.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @wenscarl @ttyio
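For context, the breakage comes from #1355 no longer exposing a bare `SWIZZLED` member on the layout enum, so each call site now has to name the tile size explicitly. The sketch below is illustrative only: apart from `SWIZZLED_128x4`, the member names and comments are assumptions, not the actual definition in the FP4 quantization headers.

```cpp
// Illustrative sketch of the enum change behind this fix (not the real header).
// Only SWIZZLED_128x4 is confirmed by this diff; the other member and its
// semantics are assumptions made for illustration.
enum class FP4QuantizationSFLayout {
  SWIZZLED_128x4,  // scale factors stored in 128x4 swizzled tiles (previously the bare SWIZZLED member)
  LINEAR,          // assumed: a non-swizzled, row-major layout variant
};

// Call sites in cutlass_fused_moe_kernels.cuh now pass the qualified name, e.g.:
//   cvt_quant_to_fp4_get_sf_out_offset<ElementSF, NumThreadsPerSF, VecSize>(
//       ..., act_sf_expert, FP4QuantizationSFLayout::SWIZZLED_128x4);
```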
1 parent dbb438c commit 22a62ea

File tree

1 file changed: +3 −3 lines

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (3 additions, 3 deletions)
@@ -983,7 +983,7 @@ __device__ auto quantizePackedFPXValue(
   auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
                                                    NumThreadsPerSF, VecSize>(
       std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx,
-      std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED);
+      std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED_128x4);
 
   // Do the conversion and set the output and scaling factor
   auto func = [&]() {
@@ -1023,15 +1023,15 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id,
   auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
                                                    NumThreadsPerSF, VecSize>(
       std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx,
-      std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED);
+      std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED_128x4);
   if (sf_out) {
     if (input_sf) {
       auto const sf_in =
           cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
                                              NumThreadsPerSF, VecSize>(
               std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
               num_cols, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
-              FP4QuantizationSFLayout::SWIZZLED);
+              FP4QuantizationSFLayout::SWIZZLED_128x4);
       *sf_out = *sf_in;
     } else {
       *sf_out = 0x00;
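Note that the same layout constant is passed on both the write path (`sf_out`, the per-expert activation scale factors) and the read path (`sf_in`, from `input_sf`). Presumably `cvt_quant_to_fp4_get_sf_out_offset` derives the element offset from this layout, so if the two call sites disagreed, the copied scale factors would land at mismatched positions; keeping all three sites on `SWIZZLED_128x4` restores the behavior that existed before #1355.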
