@@ -236,7 +236,7 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
236
236
raise RuntimeError (f"Failed to generate Cutlass kernels: { e } " ) from e
237
237
238
238
return gen_jit_spec (
239
- "fused_moe_sm100 " ,
239
+ "fused_moe_cutlass_sm100 " ,
240
240
[
241
241
jit_env .FLASHINFER_CSRC_DIR
242
242
/ "nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu" ,
@@ -322,7 +322,7 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
322
322
323
323
@functools .cache
324
324
def get_cutlass_fused_moe_sm100_module (use_fast_build : bool = False ):
325
- gen_cutlass_fused_moe_sm100_module (use_fast_build ).build_and_load (
325
+ FusedMoeRunner = gen_cutlass_fused_moe_sm100_module (use_fast_build ).build_and_load (
326
326
class_name = "FusedMoeRunner"
327
327
)
328
328
@@ -385,15 +385,13 @@ def __init__(
385
385
)
386
386
387
387
if instance_key not in MoERunner .runner_dict :
388
- MoERunner .runner_dict [instance_key ] = (
389
- torch .classes .fused_moe_sm100 .FusedMoeRunner (
390
- x_dtype ,
391
- weight_dtype ,
392
- output_dtype ,
393
- use_deepseek_fp8_block_scale ,
394
- use_w4a8_group_scaling ,
395
- use_mxfp8_act_scaling ,
396
- )
388
+ MoERunner .runner_dict [instance_key ] = FusedMoeRunner (
389
+ x_dtype ,
390
+ weight_dtype ,
391
+ output_dtype ,
392
+ use_deepseek_fp8_block_scale ,
393
+ use_w4a8_group_scaling ,
394
+ use_mxfp8_act_scaling ,
397
395
)
398
396
399
397
self .fused_moe_runner = MoERunner .runner_dict [instance_key ]
@@ -819,7 +817,7 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
819
817
]
820
818
821
819
return gen_jit_spec (
822
- "fused_moe_sm100 " ,
820
+ "fused_moe_trtllm_sm100 " ,
823
821
[
824
822
jit_env .FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp" ,
825
823
jit_env .FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp" ,
0 commit comments