Skip to content

Commit 1aedbf2

Browse files
committed
init
1 parent 661db1f commit 1aedbf2

File tree

6 files changed

+33
-19808
lines changed

6 files changed

+33
-19808
lines changed

flashinfer/fused_moe.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .jit import JitSpec
4040
from .jit import env as jit_env
4141
from .jit import gen_jit_spec, setup_cubin_loader, sm100a_nvcc_flags
42+
from .jit.cubin_loader import get_cubin
4243
from .utils import _check_shape_dtype_device, register_custom_op, register_fake_op
4344

4445

@@ -773,6 +774,13 @@ def cutlass_fused_moe(
773774

774775

775776
def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
777+
hash = "f5deee96023f1d74b1ff71ac69f782a96741a053"
778+
metainfo = get_cubin(
779+
f"{hash}/batched_gemm-c603ed2-3fa89e1/include/KernelMetaInfo",
780+
"d789c63aaeee1aa0a68ebf22fa693b6b82a7c2319bd933a00a10306ca08d9e0e",
781+
".h",
782+
)
783+
assert metainfo, "KernelMetaInfo.h not found"
776784
return gen_jit_spec(
777785
"fused_moe_sm100",
778786
[
@@ -790,6 +798,13 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
790798
"-DENABLE_FP4",
791799
]
792800
+ sm100a_nvcc_flags,
801+
extra_include_paths=[
802+
jit_env.FLASHINFER_CACHE_DIR
803+
/ "cubins"
804+
/ hash
805+
/ "batched_gemm-c603ed2-3fa89e1"
806+
/ "include"
807+
],
793808
extra_ldflags=["-lcuda"],
794809
)
795810

flashinfer/jit/attention/pytorch.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from .. import env as jit_env
2424
from ..core import JitSpec, gen_jit_spec, logger, sm90a_nvcc_flags, sm100a_nvcc_flags
25+
from ..cubin_loader import get_cubin
2526
from ..utils import (
2627
dtype_map,
2728
filename_safe_dtype_map,
@@ -1482,12 +1483,27 @@ def gen_fmha_cutlass_sm100a_module(
14821483

14831484

14841485
def trtllm_gen_fmha_module():
1486+
hash = "5f2779e6df822bc0b26940b6d3b0059c86f0a6a1"
1487+
metainfo = get_cubin(
1488+
f"{hash}/fmha/trtllm-gen/include/flashInferMetaInfo",
1489+
"11f31dc81f996e39c3f1d85d773864c9113c5837619e21418a846befa4f8dddd",
1490+
".h",
1491+
)
1492+
assert metainfo, "flashInferMetaInfo.h not found"
14851493
return gen_jit_spec(
14861494
"fmha_gen",
14871495
[
14881496
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fmha_runner.cu",
14891497
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fmha_kernel_launcher.cu",
14901498
],
1499+
extra_include_paths=[
1500+
jit_env.FLASHINFER_CACHE_DIR
1501+
/ "cubins"
1502+
/ hash
1503+
/ "fmha"
1504+
/ "trtllm-gen"
1505+
/ "include"
1506+
],
14911507
extra_ldflags=["-lcuda"],
14921508
)
14931509

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
645645

646646
auto fiModuleLoadData = [&](CUmodule* module) {
647647
const std::string sha256 = config.mHash ? config.mHash : "";
648-
const std::string pipeline_hash = "991e7438224199de85ef08a2730ce18c12b4e0aa";
648+
const std::string pipeline_hash = "f5deee96023f1d74b1ff71ac69f782a96741a053";
649649
const std::string cubin_path = pipeline_hash + "/" + std::string("batched_gemm-") +
650650
TLLM_GEN_COMMIT + "-" + TLLM_GEN_BATCHED_GEMM_CONFIG_HASH + "/";
651651
std::string fname_cubin = config.mFunctionName;

0 commit comments

Comments
 (0)