Skip to content

Commit 6334da9

Browse files
committed
Add gen_trtllm_low_latency_gemm_module to the AOT build in flashinfer/aot.py and import it in flashinfer/jit/gemm/__init__.py
1 parent c15ccdb commit 6334da9

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

flashinfer/aot.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@
5757
gen_gemm_sm120_module,
5858
gen_gemm_sm120_module_cutlass_fp4,
5959
gen_trtllm_gen_gemm_module,
60+
gen_trtllm_low_latency_gemm_module,
6061
)
6162
from .jit.spdlog import gen_spdlog_module
6263
from .jit.mla import gen_mla_module
@@ -460,6 +461,7 @@ def gen_all_modules(
460461
)
461462
jit_specs.append(gen_mxfp8_quantization_sm100_module())
462463
jit_specs.append(gen_trtllm_gen_gemm_module())
464+
jit_specs.append(gen_trtllm_low_latency_gemm_module())
463465
jit_specs.append(gen_trtllm_gen_fused_moe_sm100_module())
464466
if has_sm100f:
465467
# Add TGV GEMM modules compiled with SM100f flags for both bf16 and fp16

flashinfer/jit/gemm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
gen_gemm_sm100_module,
2323
gen_gemm_sm120_module,
2424
gen_trtllm_gen_gemm_module,
25+
gen_trtllm_low_latency_gemm_module,
2526
gen_tgv_gemm_sm10x_module,
2627
gen_gemm_sm90_module,
2728
)
@@ -35,6 +36,7 @@
3536
"gen_gemm_sm100_module",
3637
"gen_gemm_sm120_module",
3738
"gen_trtllm_gen_gemm_module",
39+
"gen_trtllm_low_latency_gemm_module",
3840
"gen_tgv_gemm_sm10x_module",
3941
"gen_gemm_sm90_module",
4042
"gen_deepgemm_sm100_module",

0 commit comments

Comments
 (0)