Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
class ArtifactPath:
TRTLLM_GEN_FMHA: str = "c8e0abb4b0438880a2b0a9b68449e3cf1513aadf/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"c8e0abb4b0438880a2b0a9b68449e3cf1513aadf/batched_gemm-32110eb-a15c257/"
"5d347c6234c9f0e7f1ab6519ea933183b48216ed/batched_gemm-32110eb-5262bae/"
)
TRTLLM_GEN_GEMM: str = (
"07a5f242a649533ff6885f87c42b2476a9e46233/gemm-c603ed2-434a6e1/"
"5d347c6234c9f0e7f1ab6519ea933183b48216ed/gemm-32110eb-434a6e1/"
)
CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn/"
DEEPGEMM: str = "d25901733420c7cddc1adf799b0d4639ed1e162f/deep-gemm/"
Expand All @@ -63,9 +63,12 @@ class MetaInfoHash:
TRTLLM_GEN_FMHA: str = (
"0d124e546c8a2e9fa59499625e8a6d140a2465573d4a3944f9d29f29f73292fb"
)
TRTLLM_GEN_BMM: str = (
"aae02e5703ee0ce696c4b3a1f2a32936fcc960dcb69fdef52b6d0f8a7b673000"
)
DEEPGEMM: str = "69aa277b7f3663ed929e73f9c57301792b8c594dac15a465b44a5d151b6a1d50"
TRTLLM_GEN_GEMM: str = (
"50c5627324003c822efbdd1d368b1e569f4f67f4bb0a2fbb7397cd56c6d14c2a"
"a00ef9d834cb66c724ec7c72337bc955dc53070a65a6f68b34f852d144fa6ea3"
)


Expand All @@ -74,7 +77,8 @@ def download_artifacts() -> bool:
os.environ["FLASHINFER_CUBIN_CHECKSUM_DISABLED"] = "1"
cubin_files = [
(ArtifactPath.TRTLLM_GEN_FMHA + "flashInferMetaInfo", ".h"),
(ArtifactPath.TRTLLM_GEN_GEMM + "KernelMetaInfo", ".h"),
(ArtifactPath.TRTLLM_GEN_GEMM + "include/flashinferMetaInfo", ".h"),
(ArtifactPath.TRTLLM_GEN_BMM + "include/flashinferMetaInfo", ".h"),
]
for kernel in [
ArtifactPath.TRTLLM_GEN_FMHA,
Expand Down
9 changes: 4 additions & 5 deletions flashinfer/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .artifacts import ArtifactPath, MetaInfoHash
from .cuda_utils import checkCudaErrors
from .jit.cubin_loader import get_cubin
from .jit.env import FLASHINFER_CACHE_DIR
from .jit.env import FLASHINFER_CUBIN_DIR
from .utils import ceil_div, round_up


Expand Down Expand Up @@ -908,8 +908,7 @@ def load_all():
symbol, sha256 = KERNEL_MAP[cubin_name]
get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
path = (
FLASHINFER_CACHE_DIR
/ "cubins"
FLASHINFER_CUBIN_DIR
/ f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
)
assert path.exists()
Expand All @@ -926,7 +925,7 @@ def load(name: str, code: str) -> SM100FP8GemmRuntime:
symbol, sha256 = KERNEL_MAP[cubin_name]
get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
path = (
FLASHINFER_CACHE_DIR / "cubins" / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
FLASHINFER_CUBIN_DIR / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
)
assert path.exists()
RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)
Expand Down Expand Up @@ -1460,7 +1459,7 @@ def init_indices(self):
assert get_cubin(indice_path, self.sha256, file_extension=".json"), (
"cubin kernel map file not found, nor downloaded with matched sha256"
)
path = FLASHINFER_CACHE_DIR / "cubins" / f"{indice_path}.json"
path = FLASHINFER_CUBIN_DIR / f"{indice_path}.json"
assert path.exists()
with open(path, "r") as f:
self.indice = json.load(f)
Expand Down
28 changes: 16 additions & 12 deletions flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@

import functools
from enum import IntEnum
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple, Union

import torch

from ..artifacts import ArtifactPath
from ..artifacts import ArtifactPath, MetaInfoHash
from ..autotuner import (
AutoTuner,
DynamicTensorSpec,
Expand All @@ -33,6 +32,7 @@
from ..jit import JitSpec
from ..jit import env as jit_env
from ..jit import gen_jit_spec, setup_cubin_loader, sm100a_nvcc_flags
from ..jit.cubin_loader import get_cubin
from ..jit.cutlass_gemm.generate_kernels import generate_gemm_operations
from ..utils import (
_check_shape_dtype_device,
Expand Down Expand Up @@ -819,15 +819,18 @@ def cutlass_fused_moe(


def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
debug_cubin_path = (
jit_env.FLASHINFER_INCLUDE_DIR
/ "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/cubins"
# Fetch "flashinferMetaInfo.h" from the online kernel cache. This file
# contains the `tllmGenBatchedGemmList` as the list of available kernels
# online. It is included when compiling `trtllm_fused_moe_runner.cu`, etc.
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
header_name = "flashinferMetaInfo"

# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}", MetaInfoHash.TRTLLM_GEN_BMM, ".h"
)
import glob

debug_cubin_files = [
Path(p) for p in glob.glob(str(debug_cubin_path / "Bmm_*.cpp"))
]
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"

return gen_jit_spec(
"fused_moe_trtllm_sm100",
Expand All @@ -844,8 +847,7 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_renormalize.cu",
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_dev_kernel.cu",
jit_env.FLASHINFER_CSRC_DIR / "trtllm_batched_gemm_runner.cu",
]
+ debug_cubin_files,
],
extra_cuda_cflags=[
"-DTLLM_GEN_EXPORT_INTERFACE",
"-DTLLM_ENABLE_CUDA",
Expand All @@ -857,6 +859,8 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
+ sm100a_nvcc_flags,
extra_ldflags=["-lcuda"],
extra_include_paths=[
# link "include" sub-directory in cache
jit_env.FLASHINFER_CUBIN_DIR / include_path,
jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/include",
],
Expand Down
18 changes: 11 additions & 7 deletions flashinfer/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,12 +349,19 @@ def get_gemm_sm100_module():


def trtllm_gemm_gen_module() -> JitSpec:
header_name = "KernelMetaInfo"
# Fetch "flashinferMetaInfo.h" from the online kernel cache. This file
# contains the `tllmGenGemmList` as the list of available kernels online.
# It is included when compiling `trtllm_gemm_runner.cu`.
include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include"
header_name = "flashinferMetaInfo"

# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{ArtifactPath.TRTLLM_GEN_GEMM}/{header_name}",
f"{include_path}/{header_name}",
MetaInfoHash.TRTLLM_GEN_GEMM,
".h",
)
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"
return gen_jit_spec(
"trtllm_gemm",
Expand All @@ -367,11 +374,8 @@ def trtllm_gemm_gen_module() -> JitSpec:
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',
]
+ sm100a_nvcc_flags,
extra_include_paths=[
jit_env.FLASHINFER_CACHE_DIR / "cubins" / ArtifactPath.TRTLLM_GEN_GEMM,
jit_env.FLASHINFER_INCLUDE_DIR
/ "flashinfer/trtllm/gemm/trtllmGen_gemm_export",
],
# link "include" sub-directory in cache
extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
extra_ldflags=["-lcuda"],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include "trtllm/gen/CudaKernelLauncher.h"

#ifdef TLLM_GEN_EXPORT_INTERFACE
#include "KernelMetaInfo.h"
#include "flashinferMetaInfo.h"
#endif // TLLM_GEN_EXPORT_INTERFACE

#ifdef TLLM_GEN_BMM_CUBIN_PATH
Expand Down Expand Up @@ -509,7 +509,8 @@ BatchedGemmConfig const* BatchedGemmInterface::getBatchedGemmConfigs() const {

size_t BatchedGemmInterface::getNumBatchedGemmConfigs() const {
#ifdef TLLM_GEN_EXPORT_INTERFACE
return tensorrt_llm::kernels::tllmGenBatchedGemmListLen;
return sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList) /
sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList[0]);
#else
return 0;
#endif
Expand Down
Loading
Loading