Skip to content

Commit c5a6ca1

Browse files
committed
rebase
1 parent fc88829 commit c5a6ca1

File tree

4 files changed

+19
-34736
lines changed

4 files changed

+19
-34736
lines changed

flashinfer/artifacts.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
5050
class ArtifactPath:
5151
TRTLLM_GEN_FMHA: str = "c8e0abb4b0438880a2b0a9b68449e3cf1513aadf/fmha/trtllm-gen/"
5252
TRTLLM_GEN_BMM: str = (
53-
"c8e0abb4b0438880a2b0a9b68449e3cf1513aadf/batched_gemm-32110eb-a15c257/"
53+
"80848c8aa91d7bb650b762e7d5fa98abb16ed982/batched_gemm-32110eb-5262bae/"
5454
)
5555
TRTLLM_GEN_GEMM: str = (
5656
"07a5f242a649533ff6885f87c42b2476a9e46233/gemm-c603ed2-434a6e1/"
@@ -63,6 +63,9 @@ class MetaInfoHash:
6363
TRTLLM_GEN_FMHA: str = (
6464
"0d124e546c8a2e9fa59499625e8a6d140a2465573d4a3944f9d29f29f73292fb"
6565
)
66+
TRTLLM_GEN_BMM: str = (
67+
"23243b86451ba2a9c20e4456c14a86eb6c2204cd00a2972405f66643b677d01f"
68+
)
6669
DEEPGEMM: str = "69aa277b7f3663ed929e73f9c57301792b8c594dac15a465b44a5d151b6a1d50"
6770
TRTLLM_GEN_GEMM: str = (
6871
"50c5627324003c822efbdd1d368b1e569f4f67f4bb0a2fbb7397cd56c6d14c2a"
@@ -75,6 +78,7 @@ def download_artifacts() -> bool:
7578
cubin_files = [
7679
(ArtifactPath.TRTLLM_GEN_FMHA + "flashInferMetaInfo", ".h"),
7780
(ArtifactPath.TRTLLM_GEN_GEMM + "KernelMetaInfo", ".h"),
81+
(ArtifactPath.TRTLLM_GEN_GEMM + "include/flashinferMetaInfo", ".h"),
7882
]
7983
for kernel in [
8084
ArtifactPath.TRTLLM_GEN_FMHA,

flashinfer/fused_moe/core.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import torch
2424

25-
from ..artifacts import ArtifactPath
25+
from ..artifacts import ArtifactPath, MetaInfoHash
2626
from ..autotuner import (
2727
AutoTuner,
2828
DynamicTensorSpec,
@@ -33,6 +33,7 @@
3333
from ..jit import JitSpec
3434
from ..jit import env as jit_env
3535
from ..jit import gen_jit_spec, setup_cubin_loader, sm100a_nvcc_flags
36+
from ..jit.cubin_loader import get_cubin
3637
from ..jit.cutlass_gemm.generate_kernels import generate_gemm_operations
3738
from ..utils import _check_shape_dtype_device, register_custom_op, register_fake_op
3839
from .utils import (
@@ -752,14 +753,17 @@ def cutlass_fused_moe(
752753

753754

754755
def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
755-
debug_cubin_path = (
756-
jit_env.FLASHINFER_INCLUDE_DIR
757-
/ "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/cubins"
758-
)
759756
import glob
760757

758+
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
759+
760+
metainfo = get_cubin(
761+
f"{include_path}/flashinferMetaInfo", MetaInfoHash.TRTLLM_GEN_BMM, ".h"
762+
)
763+
assert metainfo, "KernelMetaInfo.h not found"
764+
761765
debug_cubin_files = [
762-
Path(p) for p in glob.glob(str(debug_cubin_path / "Bmm_*.cpp"))
766+
Path(p) for p in glob.glob(str(f"{ArtifactPath.TRTLLM_GEN_BMM}/Bmm_*.cpp"))
763767
]
764768

765769
return gen_jit_spec(
@@ -790,6 +794,7 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
790794
+ sm100a_nvcc_flags,
791795
extra_ldflags=["-lcuda"],
792796
extra_include_paths=[
797+
jit_env.FLASHINFER_CACHE_DIR / "cubins" / include_path,
793798
jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
794799
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/include",
795800
],

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#include "trtllm/gen/CudaKernelLauncher.h"
2525

2626
#ifdef TLLM_GEN_EXPORT_INTERFACE
27-
#include "KernelMetaInfo.h"
27+
#include "flashinferMetaInfo.h"
2828
#endif // TLLM_GEN_EXPORT_INTERFACE
2929

3030
#ifdef TLLM_GEN_BMM_CUBIN_PATH
@@ -509,7 +509,8 @@ BatchedGemmConfig const* BatchedGemmInterface::getBatchedGemmConfigs() const {
509509

510510
size_t BatchedGemmInterface::getNumBatchedGemmConfigs() const {
511511
#ifdef TLLM_GEN_EXPORT_INTERFACE
512-
return tensorrt_llm::kernels::tllmGenBatchedGemmListLen;
512+
return sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList) /
513+
sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList[0]);
513514
#else
514515
return 0;
515516
#endif

0 commit comments

Comments
 (0)