6 changes: 5 additions & 1 deletion flashinfer/artifacts.py
@@ -50,7 +50,7 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
 class ArtifactPath:
     TRTLLM_GEN_FMHA: str = "c8e0abb4b0438880a2b0a9b68449e3cf1513aadf/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
-        "c8e0abb4b0438880a2b0a9b68449e3cf1513aadf/batched_gemm-32110eb-a15c257/"
+        "80848c8aa91d7bb650b762e7d5fa98abb16ed982/batched_gemm-32110eb-5262bae/"
     )
     TRTLLM_GEN_GEMM: str = (
         "07a5f242a649533ff6885f87c42b2476a9e46233/gemm-c603ed2-434a6e1/"
@@ -63,6 +63,9 @@ class MetaInfoHash:
     TRTLLM_GEN_FMHA: str = (
         "0d124e546c8a2e9fa59499625e8a6d140a2465573d4a3944f9d29f29f73292fb"
     )
+    TRTLLM_GEN_BMM: str = (
+        "23243b86451ba2a9c20e4456c14a86eb6c2204cd00a2972405f66643b677d01f"
+    )
     DEEPGEMM: str = "69aa277b7f3663ed929e73f9c57301792b8c594dac15a465b44a5d151b6a1d50"
     TRTLLM_GEN_GEMM: str = (
         "50c5627324003c822efbdd1d368b1e569f4f67f4bb0a2fbb7397cd56c6d14c2a"
@@ -75,6 +78,7 @@ def download_artifacts() -> bool:
     cubin_files = [
         (ArtifactPath.TRTLLM_GEN_FMHA + "flashInferMetaInfo", ".h"),
         (ArtifactPath.TRTLLM_GEN_GEMM + "KernelMetaInfo", ".h"),
+        (ArtifactPath.TRTLLM_GEN_BMM + "include/flashinferMetaInfo", ".h"),
     ]
     for kernel in [
         ArtifactPath.TRTLLM_GEN_FMHA,
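A note on how these entries are consumed: each `cubin_files` tuple pairs an artifact-relative path with a file extension, and the pinned digests in `MetaInfoHash` guard the download. A minimal sketch of fetching and verifying the new BMM metainfo header, mirroring the `get_cubin` call added in `core.py` below (that `get_cubin` checks the SHA-256 digest and returns a truthy value on success is an assumption inferred from the `assert` in that change):

```python
from flashinfer.artifacts import ArtifactPath, MetaInfoHash
from flashinfer.jit.cubin_loader import get_cubin

# Fetch flashinferMetaInfo.h from the pinned BMM artifact directory and
# (assumed) verify it against the pinned SHA-256 digest before caching.
metainfo = get_cubin(
    f"{ArtifactPath.TRTLLM_GEN_BMM}/include/flashinferMetaInfo",
    MetaInfoHash.TRTLLM_GEN_BMM,
    ".h",
)
assert metainfo, "flashinferMetaInfo.h not found or failed verification"
```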
17 changes: 11 additions & 6 deletions flashinfer/fused_moe/core.py
@@ -22,7 +22,7 @@
 
 import torch
 
-from ..artifacts import ArtifactPath
+from ..artifacts import ArtifactPath, MetaInfoHash
 from ..autotuner import (
     AutoTuner,
     DynamicTensorSpec,
@@ -33,6 +33,7 @@
 from ..jit import JitSpec
 from ..jit import env as jit_env
 from ..jit import gen_jit_spec, setup_cubin_loader, sm100a_nvcc_flags
+from ..jit.cubin_loader import get_cubin
 from ..jit.cutlass_gemm.generate_kernels import generate_gemm_operations
 from ..utils import _check_shape_dtype_device, register_custom_op, register_fake_op
 from .utils import (
@@ -752,14 +753,17 @@ def cutlass_fused_moe(
 
 
 def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
-    debug_cubin_path = (
-        jit_env.FLASHINFER_INCLUDE_DIR
-        / "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/cubins"
-    )
     import glob
 
+    include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
+
+    metainfo = get_cubin(
+        f"{include_path}/flashinferMetaInfo", MetaInfoHash.TRTLLM_GEN_BMM, ".h"
+    )
+    assert metainfo, "flashinferMetaInfo.h not found"
+
     debug_cubin_files = [
-        Path(p) for p in glob.glob(str(debug_cubin_path / "Bmm_*.cpp"))
+        Path(p) for p in glob.glob(str(f"{ArtifactPath.TRTLLM_GEN_BMM}/Bmm_*.cpp"))
     ]
 
     return gen_jit_spec(
@@ -790,6 +794,7 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
         + sm100a_nvcc_flags,
         extra_ldflags=["-lcuda"],
         extra_include_paths=[
+            jit_env.FLASHINFER_CACHE_DIR / "cubins" / include_path,
             jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
             jit_env.FLASHINFER_CSRC_DIR / "nv_internal/include",
         ],

Collaborator review comment on the added `extra_include_paths` entry: "This might be conflicting with #1462, consider using "FLASHINFER_CUBIN_DIR" instead"
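Regarding the reviewer's note: the added `extra_include_paths` entry assumes `get_cubin` stores downloaded artifacts under `FLASHINFER_CACHE_DIR / "cubins"`, which is what lets the JIT-compiled C++ below resolve the bare `#include "flashinferMetaInfo.h"`. A small sketch of the layout this relies on (the "cubins" cache root is read off the diff, not a documented contract, and #1462 may replace it with a dedicated FLASHINFER_CUBIN_DIR):

```python
from flashinfer.artifacts import ArtifactPath
from flashinfer.jit import env as jit_env

include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"

# Directory handed to the compiler; the header fetched by get_cubin is
# expected to sit directly inside it.
header_dir = jit_env.FLASHINFER_CACHE_DIR / "cubins" / include_path
print(header_dir / "flashinferMetaInfo.h")
```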
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
@@ -24,7 +24,7 @@
 #include "trtllm/gen/CudaKernelLauncher.h"
 
 #ifdef TLLM_GEN_EXPORT_INTERFACE
-#include "KernelMetaInfo.h"
+#include "flashinferMetaInfo.h"
 #endif  // TLLM_GEN_EXPORT_INTERFACE
 
 #ifdef TLLM_GEN_BMM_CUBIN_PATH
@@ -509,7 +509,8 @@ BatchedGemmConfig const* BatchedGemmInterface::getBatchedGemmConfigs() const {
 
 size_t BatchedGemmInterface::getNumBatchedGemmConfigs() const {
 #ifdef TLLM_GEN_EXPORT_INTERFACE
-  return tensorrt_llm::kernels::tllmGenBatchedGemmListLen;
+  return sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList) /
+         sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList[0]);
 #else
   return 0;
 #endif
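Design note on the count change: deriving the length with `sizeof(tllmGenBatchedGemmList) / sizeof(tllmGenBatchedGemmList[0])` computes the element count from the array definition itself, so the generated `flashinferMetaInfo.h` only needs to define the kernel list and no longer has to export a separately maintained `tllmGenBatchedGemmListLen` constant that could drift out of sync with the array.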