Commit d3e9b44

bugfix: add check for empty MoE tactics and allow sm121 to use sm120 config (#1861)
## 📌 Description

- Add a safety check for empty tactics to provide a clear error message.
- Allow SM121 devices to use SM120 kernel configurations.
- Mark `test_moe_mxfp8_mxfp4` as xfail (kernel not yet implemented) for SM120/SM121.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc: @djmmoss @yzh119 @bkryu @cyx-6 @nv-yunzheq
1 parent c691768 · commit d3e9b44

File tree: 3 files changed (+10 / −3 lines)

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Lines changed: 5 additions & 0 deletions

```diff
@@ -220,6 +220,11 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {

     mProfiler = std::make_shared<kernels::GemmProfilerBackend>();
     mAllProfiles = mKernelRunner->getTactics();
+    TVM_FFI_ICHECK(!mAllProfiles.empty())
+        << "No valid tactics available for fused moe op with the requested input combination "
+           "Activation: "
+        << DLDataTypeToString(mActivationDtype) << ", Weight: " << DLDataTypeToString(mWeightDtype)
+        << ", Output: " << DLDataTypeToString(mOutputDtype);
   }

   void runMoe(Tensor output, Tensor input, Tensor token_selected_experts,
```
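For context, a minimal Python sketch of the behavior this guard introduces: if the tactic lookup for the requested dtype combination comes back empty, fail fast with a descriptive error instead of crashing later during profiling. The `get_tactics` function and the registry contents are hypothetical stand-ins, not FlashInfer's actual API.

```python
# Illustrative sketch only -- get_tactics and the registry below are
# hypothetical stand-ins for mKernelRunner->getTactics() on the C++ side.
def get_tactics(activation_dtype: str, weight_dtype: str, output_dtype: str) -> list:
    supported = {
        # (activation, weight, output) -> candidate kernel tactics
        ("bfloat16", "mxfp4", "bfloat16"): ["tactic_a", "tactic_b"],
    }
    tactics = supported.get((activation_dtype, weight_dtype, output_dtype), [])
    if not tactics:  # mirrors TVM_FFI_ICHECK(!mAllProfiles.empty())
        raise ValueError(
            "No valid tactics available for fused moe op with the requested "
            f"input combination Activation: {activation_dtype}, "
            f"Weight: {weight_dtype}, Output: {output_dtype}"
        )
    return tactics
```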

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h

Lines changed: 3 additions & 1 deletion

```diff
@@ -731,7 +731,9 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
   // 110 below logging helps confirming the cutlass pipeline matches the device major version
   bool is_sm110 = inputs.gemm_config.sm_version == 100 && sm_ == 110;
   bool is_sm103 = inputs.gemm_config.sm_version == 100 && sm_ == 103;
-  TLLM_CHECK_WITH_INFO(is_same_sm || is_sm110 || is_sm103,
+  // SM120 and SM121 are architecturally identical
+  bool is_sm120 = (inputs.gemm_config.sm_version == 120) && (sm_ == 120 || sm_ == 121);
+  TLLM_CHECK_WITH_INFO(is_same_sm || is_sm110 || is_sm103 || is_sm120,
                        "Using SM %d configuration for SM %d device",
                        inputs.gemm_config.sm_version, sm_);
   TLLM_CHECK_WITH_INFO(inputs.biases != nullptr || hopper_inputs.ptr_c == nullptr,
```
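The dispatch check reads as a small compatibility predicate over (configuration SM, device SM) pairs. A Python sketch, assuming only the pairings visible in the diff are allowed; `config_matches_device` is an illustrative name, not a function in the codebase.

```python
def config_matches_device(config_sm: int, device_sm: int) -> bool:
    """Illustrative mirror of the TLLM_CHECK_WITH_INFO condition above."""
    if config_sm == device_sm:
        return True  # is_same_sm
    if config_sm == 100 and device_sm in (110, 103):
        return True  # is_sm110 / is_sm103: these devices reuse SM100 configs
    if config_sm == 120 and device_sm == 121:
        return True  # is_sm120: SM120 and SM121 are architecturally identical
    return False

assert config_matches_device(120, 121)      # newly allowed by this commit
assert not config_matches_device(120, 100)  # still rejected
```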

tests/moe/test_trtllm_cutlass_fused_moe.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -1093,8 +1093,8 @@ def dequant_mxfp4_batches(
     ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)]
 )
 @pytest.mark.skipif(
-    torch.cuda.get_device_capability()[0] not in [10, 11, 12],
-    reason="MXFP8xMXFP4 is only supported on SM100, SM110 and SM120",
+    torch.cuda.get_device_capability()[0] not in [10, 11],
+    reason="MXFP8xMXFP4 is only supported on SM100 and SM110",
 )
 def test_moe_mxfp8_mxfp4(
     batch_size,
```
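On the test side, `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, so SM120 reports `(12, 0)` and SM121 reports `(12, 1)`; the narrowed `skipif` keeps only SM10x/SM11x devices in scope. A hedged sketch of how the xfail described in the commit message could be expressed for SM120/SM121; the `xfail_sm120_121` marker is hypothetical, not the test file's actual code.

```python
import pytest
import torch

def _device_sm() -> int:
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor  # e.g. (12, 1) -> 121

# Hypothetical marker: expect failure on SM120/SM121 until the
# MXFP8xMXFP4 kernel is implemented for those architectures.
xfail_sm120_121 = pytest.mark.xfail(
    torch.cuda.is_available() and _device_sm() in (120, 121),
    reason="MXFP8xMXFP4 kernel not yet implemented for SM120/SM121",
)
```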
