Skip to content

Commit cf1b2d2

Browse files
authored
bugfix: Fix trtllm moe launcher local_num_experts (#1398)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> Signed-off-by: Shu Wang <[email protected]>
1 parent dd20f55 commit cf1b2d2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
10621062
hidden_states_scale_vec_size = (num_tokens * hidden_size) / hidden_states_scale.value().numel();
10631063
}
10641064
int weight_scale_vec_size =
1065-
(num_experts * intermediate_size * 2 * hidden_size) / gemm1_weights_scale.numel();
1065+
(local_num_experts * intermediate_size * 2 * hidden_size) / gemm1_weights_scale.numel();
10661066
TORCH_CHECK(weight_scale_vec_size == 16 || weight_scale_vec_size == 32,
10671067
"unsupported weight_scale_vec_size.");
10681068
auto mDtypeWeights = weight_scale_vec_size == 16 ? btg::Dtype::E2m1 : btg::Dtype::MxE2m1;

0 commit comments

Comments
 (0)