Skip to content

Commit fdfe07a

Browse files
authored
refactor: download trtllm gemm metadata from server (#1378)
<!-- .github/pull_request_template.md --> ## πŸ“Œ Description refactor: download trtllm gemm metadata from server ## πŸ” Related Issues <!-- Link any related issues here --> ## πŸš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### βœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## πŸ§ͺ Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 455294d commit fdfe07a

File tree

3 files changed

+22
-2431
lines changed

3 files changed

+22
-2431
lines changed

β€Žflashinfer/artifacts.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class ArtifactPath:
5353
"991e7438224199de85ef08a2730ce18c12b4e0aa/batched_gemm-c603ed2-2dc78d9/"
5454
)
5555
TRTLLM_GEN_GEMM: str = (
56-
"fffd607babb0844f24225997409747ca38229333/gemm-c603ed2-f2b0c24/"
56+
"07a5f242a649533ff6885f87c42b2476a9e46233/gemm-c603ed2-434a6e1/"
5757
)
5858
CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn/"
5959
DEEPGEMM: str = "d25901733420c7cddc1adf799b0d4639ed1e162f/deep-gemm/"
@@ -64,12 +64,18 @@ class MetaInfoHash:
6464
"8c5630020c0452fb1cd1ea7e3b8fdbb7bf94f71bd899ed5b704a490bdb4f7368"
6565
)
6666
DEEPGEMM: str = "69aa277b7f3663ed929e73f9c57301792b8c594dac15a465b44a5d151b6a1d50"
67+
TRTLLM_GEN_GEMM: str = (
68+
"50c5627324003c822efbdd1d368b1e569f4f67f4bb0a2fbb7397cd56c6d14c2a"
69+
)
6770

6871

6972
def download_artifacts() -> bool:
7073
env_backup = os.environ.get("FLASHINFER_CUBIN_CHECKSUM_DISABLED", None)
7174
os.environ["FLASHINFER_CUBIN_CHECKSUM_DISABLED"] = "1"
72-
cubin_files = [(ArtifactPath.TRTLLM_GEN_FMHA + "flashInferMetaInfo", ".h")]
75+
cubin_files = [
76+
(ArtifactPath.TRTLLM_GEN_FMHA + "flashInferMetaInfo", ".h"),
77+
(ArtifactPath.TRTLLM_GEN_GEMM + "KernelMetaInfo", ".h"),
78+
]
7379
for kernel in [
7480
ArtifactPath.TRTLLM_GEN_FMHA,
7581
ArtifactPath.TRTLLM_GEN_BMM,

β€Žflashinfer/gemm.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import torch
2626
import torch.nn.functional as F
2727

28-
from .artifacts import ArtifactPath
28+
from .artifacts import ArtifactPath, MetaInfoHash
2929
from .autotuner import (
3030
AutoTuner,
3131
ConstraintSpec,
@@ -38,6 +38,7 @@
3838
get_last_power_of_2_num_tokens_buckets,
3939
last_positive_power_of_2,
4040
)
41+
from .jit.cubin_loader import get_cubin
4142

4243
CUDNN_AVAILABLE = False
4344
try:
@@ -302,6 +303,13 @@ def get_gemm_sm100_module():
302303

303304

304305
def trtllm_gemm_gen_module() -> JitSpec:
306+
header_name = "KernelMetaInfo"
307+
metainfo = get_cubin(
308+
f"{ArtifactPath.TRTLLM_GEN_GEMM}/{header_name}",
309+
MetaInfoHash.TRTLLM_GEN_GEMM,
310+
".h",
311+
)
312+
assert metainfo, f"{header_name}.h not found"
305313
return gen_jit_spec(
306314
"trtllm_gemm",
307315
[
@@ -313,6 +321,11 @@ def trtllm_gemm_gen_module() -> JitSpec:
313321
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',
314322
]
315323
+ sm100a_nvcc_flags,
324+
extra_include_paths=[
325+
jit_env.FLASHINFER_CACHE_DIR / "cubins" / ArtifactPath.TRTLLM_GEN_GEMM,
326+
jit_env.FLASHINFER_INCLUDE_DIR
327+
/ "flashinfer/trtllm/gemm/trtllmGen_gemm_export",
328+
],
316329
extra_ldflags=["-lcuda"],
317330
)
318331

0 commit comments

Comments
Β (0)