Skip to content

Commit 89d146f

Browse files
authored
feat: Enable multiple fused-moe backends (#1472)
1 parent 7e98d8b commit 89d146f

File tree

1 file changed

+10 -12 lines changed

1 file changed

+10 -12 lines changed

flashinfer/fused_moe/core.py

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -236,7 +236,7 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
236236
raise RuntimeError(f"Failed to generate Cutlass kernels: {e}") from e
237237

238238
return gen_jit_spec(
239-
"fused_moe_sm100",
239+
"fused_moe_cutlass_sm100",
240240
[
241241
jit_env.FLASHINFER_CSRC_DIR
242242
/ "nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu",
@@ -322,7 +322,7 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
322322

323323
@functools.cache
324324
def get_cutlass_fused_moe_sm100_module(use_fast_build: bool = False):
325-
gen_cutlass_fused_moe_sm100_module(use_fast_build).build_and_load(
325+
FusedMoeRunner = gen_cutlass_fused_moe_sm100_module(use_fast_build).build_and_load(
326326
class_name="FusedMoeRunner"
327327
)
328328

@@ -385,15 +385,13 @@ def __init__(
385385
)
386386

387387
if instance_key not in MoERunner.runner_dict:
388-
MoERunner.runner_dict[instance_key] = (
389-
torch.classes.fused_moe_sm100.FusedMoeRunner(
390-
x_dtype,
391-
weight_dtype,
392-
output_dtype,
393-
use_deepseek_fp8_block_scale,
394-
use_w4a8_group_scaling,
395-
use_mxfp8_act_scaling,
396-
)
388+
MoERunner.runner_dict[instance_key] = FusedMoeRunner(
389+
x_dtype,
390+
weight_dtype,
391+
output_dtype,
392+
use_deepseek_fp8_block_scale,
393+
use_w4a8_group_scaling,
394+
use_mxfp8_act_scaling,
397395
)
398396

399397
self.fused_moe_runner = MoERunner.runner_dict[instance_key]
@@ -819,7 +817,7 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
819817
]
820818

821819
return gen_jit_spec(
822-
"fused_moe_sm100",
820+
"fused_moe_trtllm_sm100",
823821
[
824822
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp",
825823
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp",

0 commit comments

Comments (0)