Skip to content

Commit e0aa4e7

Browse files
cyx-6joker-eph
andauthored
refactor: Improved metainfo for trtllm-gen kernels (#1328)
<!-- .github/pull_request_template.md --> ## 📌 Description Move trtllm-gen batched-gemm and gemm metainfo headers into cache directory and link when jit. <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent 30319e7 commit e0aa4e7

File tree

7 files changed

+45
-34764
lines changed

7 files changed

+45
-34764
lines changed

flashinfer/artifacts.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
5050
class ArtifactPath:
5151
TRTLLM_GEN_FMHA: str = "c8e0abb4b0438880a2b0a9b68449e3cf1513aadf/fmha/trtllm-gen/"
5252
TRTLLM_GEN_BMM: str = (
53-
"c8e0abb4b0438880a2b0a9b68449e3cf1513aadf/batched_gemm-32110eb-a15c257/"
53+
"5d347c6234c9f0e7f1ab6519ea933183b48216ed/batched_gemm-32110eb-5262bae/"
5454
)
5555
TRTLLM_GEN_GEMM: str = (
56-
"07a5f242a649533ff6885f87c42b2476a9e46233/gemm-c603ed2-434a6e1/"
56+
"5d347c6234c9f0e7f1ab6519ea933183b48216ed/gemm-32110eb-434a6e1/"
5757
)
5858
CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn/"
5959
DEEPGEMM: str = "d25901733420c7cddc1adf799b0d4639ed1e162f/deep-gemm/"
@@ -63,9 +63,12 @@ class MetaInfoHash:
6363
TRTLLM_GEN_FMHA: str = (
6464
"0d124e546c8a2e9fa59499625e8a6d140a2465573d4a3944f9d29f29f73292fb"
6565
)
66+
TRTLLM_GEN_BMM: str = (
67+
"aae02e5703ee0ce696c4b3a1f2a32936fcc960dcb69fdef52b6d0f8a7b673000"
68+
)
6669
DEEPGEMM: str = "69aa277b7f3663ed929e73f9c57301792b8c594dac15a465b44a5d151b6a1d50"
6770
TRTLLM_GEN_GEMM: str = (
68-
"50c5627324003c822efbdd1d368b1e569f4f67f4bb0a2fbb7397cd56c6d14c2a"
71+
"a00ef9d834cb66c724ec7c72337bc955dc53070a65a6f68b34f852d144fa6ea3"
6972
)
7073

7174

@@ -74,7 +77,8 @@ def download_artifacts() -> bool:
7477
os.environ["FLASHINFER_CUBIN_CHECKSUM_DISABLED"] = "1"
7578
cubin_files = [
7679
(ArtifactPath.TRTLLM_GEN_FMHA + "flashInferMetaInfo", ".h"),
77-
(ArtifactPath.TRTLLM_GEN_GEMM + "KernelMetaInfo", ".h"),
80+
(ArtifactPath.TRTLLM_GEN_GEMM + "include/flashinferMetaInfo", ".h"),
81+
(ArtifactPath.TRTLLM_GEN_BMM + "include/flashinferMetaInfo", ".h"),
7882
]
7983
for kernel in [
8084
ArtifactPath.TRTLLM_GEN_FMHA,

flashinfer/deep_gemm.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from .artifacts import ArtifactPath, MetaInfoHash
3838
from .cuda_utils import checkCudaErrors
3939
from .jit.cubin_loader import get_cubin
40-
from .jit.env import FLASHINFER_CACHE_DIR
40+
from .jit.env import FLASHINFER_CUBIN_DIR
4141
from .utils import ceil_div, round_up
4242

4343

@@ -907,11 +907,7 @@ def load_all():
907907
continue
908908
symbol, sha256 = KERNEL_MAP[cubin_name]
909909
get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
910-
path = (
911-
FLASHINFER_CACHE_DIR
912-
/ "cubins"
913-
/ f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
914-
)
910+
path = FLASHINFER_CUBIN_DIR / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
915911
assert path.exists()
916912
RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)
917913

@@ -925,9 +921,7 @@ def load(name: str, code: str) -> SM100FP8GemmRuntime:
925921
return RUNTIME_CACHE[cubin_name]
926922
symbol, sha256 = KERNEL_MAP[cubin_name]
927923
get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
928-
path = (
929-
FLASHINFER_CACHE_DIR / "cubins" / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
930-
)
924+
path = FLASHINFER_CUBIN_DIR / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
931925
assert path.exists()
932926
RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)
933927
return RUNTIME_CACHE[cubin_name]
@@ -1460,7 +1454,7 @@ def init_indices(self):
14601454
assert get_cubin(indice_path, self.sha256, file_extension=".json"), (
14611455
"cubin kernel map file not found, nor downloaded with matched sha256"
14621456
)
1463-
path = FLASHINFER_CACHE_DIR / "cubins" / f"{indice_path}.json"
1457+
path = FLASHINFER_CUBIN_DIR / f"{indice_path}.json"
14641458
assert path.exists()
14651459
with open(path, "r") as f:
14661460
self.indice = json.load(f)

flashinfer/fused_moe/core.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616

1717
import functools
1818
from enum import IntEnum
19-
from pathlib import Path
2019
from types import SimpleNamespace
2120
from typing import Any, Dict, List, Optional, Tuple, Union
2221

2322
import torch
2423

25-
from ..artifacts import ArtifactPath
24+
from ..artifacts import ArtifactPath, MetaInfoHash
2625
from ..autotuner import (
2726
AutoTuner,
2827
DynamicTensorSpec,
@@ -33,6 +32,7 @@
3332
from ..jit import JitSpec
3433
from ..jit import env as jit_env
3534
from ..jit import gen_jit_spec, setup_cubin_loader, sm100a_nvcc_flags
35+
from ..jit.cubin_loader import get_cubin
3636
from ..jit.cutlass_gemm.generate_kernels import generate_gemm_operations
3737
from ..utils import (
3838
_check_shape_dtype_device,
@@ -819,15 +819,18 @@ def cutlass_fused_moe(
819819

820820

821821
def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
822-
debug_cubin_path = (
823-
jit_env.FLASHINFER_INCLUDE_DIR
824-
/ "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/cubins"
822+
# Fetch "flashinferMetaInfo.h" from the online kernel cache. This file
823+
# contains the `tllmGenBatchedGemmList` as the list of available kernels
824+
# online. It is included when compiling `trtllm_fused_moe_runner.cu`, etc.
825+
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
826+
header_name = "flashinferMetaInfo"
827+
828+
# use `get_cubin` to get "flashinferMetaInfo.h"
829+
metainfo = get_cubin(
830+
f"{include_path}/{header_name}", MetaInfoHash.TRTLLM_GEN_BMM, ".h"
825831
)
826-
import glob
827-
828-
debug_cubin_files = [
829-
Path(p) for p in glob.glob(str(debug_cubin_path / "Bmm_*.cpp"))
830-
]
832+
# make sure "flashinferMetaInfo.h" is downloaded or cached
833+
assert metainfo, f"{header_name}.h not found"
831834

832835
return gen_jit_spec(
833836
"fused_moe_trtllm_sm100",
@@ -844,8 +847,7 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
844847
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_renormalize.cu",
845848
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_dev_kernel.cu",
846849
jit_env.FLASHINFER_CSRC_DIR / "trtllm_batched_gemm_runner.cu",
847-
]
848-
+ debug_cubin_files,
850+
],
849851
extra_cuda_cflags=[
850852
"-DTLLM_GEN_EXPORT_INTERFACE",
851853
"-DTLLM_ENABLE_CUDA",
@@ -857,6 +859,8 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
857859
+ sm100a_nvcc_flags,
858860
extra_ldflags=["-lcuda"],
859861
extra_include_paths=[
862+
# link "include" sub-directory in cache
863+
jit_env.FLASHINFER_CUBIN_DIR / include_path,
860864
jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
861865
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/include",
862866
],

flashinfer/gemm.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,12 +349,19 @@ def get_gemm_sm100_module():
349349

350350

351351
def trtllm_gemm_gen_module() -> JitSpec:
352-
header_name = "KernelMetaInfo"
352+
# Fetch "flashinferMetaInfo.h" from the online kernel cache. This file
353+
# contains the `tllmGenGemmList` as the list of available kernels online.
354+
# It is included when compiling `trtllm_gemm_runner.cu`.
355+
include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include"
356+
header_name = "flashinferMetaInfo"
357+
358+
# use `get_cubin` to get "flashinferMetaInfo.h"
353359
metainfo = get_cubin(
354-
f"{ArtifactPath.TRTLLM_GEN_GEMM}/{header_name}",
360+
f"{include_path}/{header_name}",
355361
MetaInfoHash.TRTLLM_GEN_GEMM,
356362
".h",
357363
)
364+
# make sure "flashinferMetaInfo.h" is downloaded or cached
358365
assert metainfo, f"{header_name}.h not found"
359366
return gen_jit_spec(
360367
"trtllm_gemm",
@@ -367,11 +374,8 @@ def trtllm_gemm_gen_module() -> JitSpec:
367374
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',
368375
]
369376
+ sm100a_nvcc_flags,
370-
extra_include_paths=[
371-
jit_env.FLASHINFER_CACHE_DIR / "cubins" / ArtifactPath.TRTLLM_GEN_GEMM,
372-
jit_env.FLASHINFER_INCLUDE_DIR
373-
/ "flashinfer/trtllm/gemm/trtllmGen_gemm_export",
374-
],
377+
# link "include" sub-directory in cache
378+
extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
375379
extra_ldflags=["-lcuda"],
376380
)
377381

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#include "trtllm/gen/CudaKernelLauncher.h"
2525

2626
#ifdef TLLM_GEN_EXPORT_INTERFACE
27-
#include "KernelMetaInfo.h"
27+
#include "flashinferMetaInfo.h"
2828
#endif // TLLM_GEN_EXPORT_INTERFACE
2929

3030
#ifdef TLLM_GEN_BMM_CUBIN_PATH
@@ -509,7 +509,8 @@ BatchedGemmConfig const* BatchedGemmInterface::getBatchedGemmConfigs() const {
509509

510510
size_t BatchedGemmInterface::getNumBatchedGemmConfigs() const {
511511
#ifdef TLLM_GEN_EXPORT_INTERFACE
512-
return tensorrt_llm::kernels::tllmGenBatchedGemmListLen;
512+
return sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList) /
513+
sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList[0]);
513514
#else
514515
return 0;
515516
#endif

0 commit comments

Comments
 (0)