Skip to content

Commit d517373

Browse files
cyx-6 and yzh119 authored
misc: Artifact downloading and single sourced artifact path (#1369)
<!-- .github/pull_request_template.md --> ## 📌 Description This PR adds the features of downloading complete artifacts and makes artifacts path single sourced in python. cc: @yyihuang @zhyncs <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Co-authored-by: Yaxing Cai <[email protected]> Co-authored-by: Zihao Ye <[email protected]>
1 parent caf7d10 commit d517373

File tree

10 files changed

+203
-47
lines changed

10 files changed

+203
-47
lines changed

csrc/cudnn_sdpa_kernel_launcher.cu

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
#include "cudnn_sdpa_utils.h"
3030
#include "pytorch_extension_utils.h"
3131

32+
#ifdef CUDNN_SDPA_CUBIN_PATH
33+
static const std::string cudnn_sdpa_cubin_path = std::string(CUDNN_SDPA_CUBIN_PATH);
34+
#else
35+
static_assert(false, "CUDNN_SDPA_CUBIN_PATH macro is not defined when compiling");
36+
#endif
37+
3238
namespace flashinfer {
3339

3440
namespace cudnn_sdpa_kernel_launcher {
@@ -77,19 +83,15 @@ enum PrefillType {
7783
};
7884

7985
// Populate `cubin_map` with the three cuDNN SDPA kernels (prefill, decode,
// and the d192 DeepSeek prefill variant). Each entry is fetched through
// getCubin() using the build-time cubin path plus the kernel's pinned
// SHA-256 checksum.
void init_cudnn_cubin(std::map<KernelType, std::string>& cubin_map) {
  // NOTE(review): checksums pin exact artifact contents — update them in
  // lockstep with CUDNN_SDPA_CUBIN_PATH when the artifacts are regenerated.
  const auto fetch = [](const char* kernel_name, const char* sha256) {
    return getCubin(cudnn_sdpa_cubin_path + kernel_name, sha256);
  };

  cubin_map[PREFILL] = fetch("cudnn_sm100_fprop_sdpa_prefill_d128_bf16",
                             "ff14e8dcfc04d9b3a912dd44056be37d9aa8a85976e0070494ca0cce0524f2a1");
  cubin_map[DECODE] = fetch("cudnn_sm100_fprop_sdpa_decode_d128_bf16",
                            "e7ce0408b4c3a36c42616498228534ee64cab785ef570af5741deaf9dd1b475c");
  cubin_map[PREFILL_DEEPSEEK] =
      fetch("cudnn_sm100_fprop_sdpa_prefill_d192_bf16",
            "2190967b8733e193cdcecc054eeb7c2907080a158a33fe7ba2004523a4aff6f9");
}
9496

9597
auto get_cudnn_cubin(KernelType kernel_type) -> std::string {

flashinfer/__main__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""
2+
Copyright (c) 2025 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
# flashinfer-cli
import argparse

from .artifacts import download_artifacts

if __name__ == "__main__":
    # Minimal CLI entry point: `python -m flashinfer --download-cubin`
    # pre-fetches every known cubin artifact into the local cache.
    cli = argparse.ArgumentParser("FlashInfer CLI")
    cli.add_argument(
        "--download-cubin", action="store_true", help="Download artifacts"
    )

    options = cli.parse_args()

    if options.download_cubin:
        succeeded = download_artifacts()
        message = (
            "✅ All cubin download tasks completed successfully."
            if succeeded
            else "❌ Some cubin download tasks failed."
        )
        print(message)

flashinfer/artifacts.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
2+
Copyright (c) 2025 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import os
18+
import re
19+
import time
20+
from concurrent.futures import ThreadPoolExecutor, as_completed
21+
22+
import requests
23+
24+
from .jit.core import logger
25+
from .jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY, get_cubin
26+
27+
28+
def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
    """Scrape an artifact directory listing for downloadable ``.cubin`` files.

    Parameters
    ----------
    source : str
        URL of the HTML directory listing to scrape.
    retries : int
        Number of fetch attempts before giving up.
    delay : int
        Seconds to sleep between failed attempts.
    timeout : int
        Per-request timeout in seconds.

    Returns
    -------
    list[tuple[str, str]]
        ``(name, ".cubin")`` pairs for each cubin link found; an empty list
        if every attempt fails.
    """
    for attempt in range(1, retries + 1):
        try:
            response = requests.get(source, timeout=timeout)
            response.raise_for_status()
            # Capture the href body with a non-greedy character class.
            # The original greedy `.*` could swallow everything between the
            # first `<a href="` and the last `.cubin">` on a line, yielding
            # corrupt names when a line holds more than one link; the fixed
            # `h[9:-8]` slicing is replaced by the capture group.
            names = re.findall(r'<a href="([^"]*)\.cubin">', response.text)
            return [(name, ".cubin") for name in names]

        except requests.exceptions.RequestException as e:
            logger.warning(
                f"Fetching available files {source}: attempt {attempt} failed: {e}"
            )

            if attempt < retries:
                logger.info(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                logger.error("Max retries reached. Fetch failed.")
                return []
48+
49+
50+
class ArtifactPath:
    """Single source of truth for the artifact-repository subdirectory of
    each kernel family's cubins (pinned by pipeline commit hash)."""

    TRTLLM_GEN_FMHA: str = "52e676342c67a3772e06f10b84600044c0c22b76/fmha/trtllm-gen/"
    TRTLLM_GEN_BMM: str = "991e7438224199de85ef08a2730ce18c12b4e0aa/batched_gemm-c603ed2-2dc78d9/"
    TRTLLM_GEN_GEMM: str = "fffd607babb0844f24225997409747ca38229333/gemm-c603ed2-f2b0c24/"
    CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn/"
    DEEPGEMM: str = "d25901733420c7cddc1adf799b0d4639ed1e162f/deep-gemm/"
60+
61+
62+
class MetaInfoHash:
    """SHA-256 checksums for the metadata files fetched alongside cubins."""

    TRTLLM_GEN_FMHA: str = "8c5630020c0452fb1cd1ea7e3b8fdbb7bf94f71bd899ed5b704a490bdb4f7368"
    DEEPGEMM: str = "69aa277b7f3663ed929e73f9c57301792b8c594dac15a465b44a5d151b6a1d50"
67+
68+
69+
def download_artifacts() -> bool:
    """Download every known cubin artifact into the local cubin cache.

    Checksum verification is temporarily disabled (the directory listing
    carries no per-file hashes); the caller's original env setting is
    restored afterwards, even on error.

    Returns
    -------
    bool
        True iff every download task succeeded.
    """
    env_backup = os.environ.get("FLASHINFER_CUBIN_CHECKSUM_DISABLED", None)
    os.environ["FLASHINFER_CUBIN_CHECKSUM_DISABLED"] = "1"
    try:
        # Metadata header plus every cubin advertised by each kernel family's
        # listing page.
        cubin_files = [(ArtifactPath.TRTLLM_GEN_FMHA + "flashInferMetaInfo", ".h")]
        for kernel in [
            ArtifactPath.TRTLLM_GEN_FMHA,
            ArtifactPath.TRTLLM_GEN_BMM,
            ArtifactPath.TRTLLM_GEN_GEMM,
            ArtifactPath.DEEPGEMM,
        ]:
            cubin_files += [
                (kernel + name, extension)
                for name, extension in get_available_cubin_files(
                    FLASHINFER_CUBINS_REPOSITORY + "/" + kernel
                )
            ]
        # Context manager joins the worker threads; the original leaked the
        # executor (never shut down).
        with ThreadPoolExecutor(4) as pool:
            futures = [
                pool.submit(get_cubin, name, "", extension)
                for name, extension in cubin_files
            ]
            results = [future.result() for future in as_completed(futures)]
        return all(results)
    finally:
        # `is None` distinguishes "unset" from an empty-string value; the
        # original `if not env_backup` dropped a pre-existing "" setting.
        # try/finally also guarantees restoration if a download raises,
        # so checksum verification is never left disabled.
        if env_backup is None:
            os.environ.pop("FLASHINFER_CUBIN_CHECKSUM_DISABLED", None)
        else:
            os.environ["FLASHINFER_CUBIN_CHECKSUM_DISABLED"] = env_backup

flashinfer/deep_gemm.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import cuda.bindings.driver as cbd
3636
import torch
3737

38+
from .artifacts import ArtifactPath, MetaInfoHash
3839
from .cuda_utils import checkCudaErrors
3940
from .jit.cubin_loader import get_cubin
4041
from .jit.env import FLASHINFER_CACHE_DIR
@@ -887,17 +888,17 @@ def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
887888
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
888889

889890

890-
_artifact_hash = "d25901733420c7cddc1adf799b0d4639ed1e162f"
891-
892-
893891
def load_all():
894892
for cubin_name in KERNEL_MAP:
895893
if cubin_name in RUNTIME_CACHE:
896894
continue
897895
symbol, sha256 = KERNEL_MAP[cubin_name]
898-
cubin_prefix = f"{_artifact_hash}/deep-gemm/"
899-
get_cubin(cubin_prefix + cubin_name, sha256)
900-
path = FLASHINFER_CACHE_DIR / "cubins" / f"{cubin_prefix + cubin_name}.cubin"
896+
get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
897+
path = (
898+
FLASHINFER_CACHE_DIR
899+
/ "cubins"
900+
/ f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
901+
)
901902
assert path.exists()
902903
RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)
903904

@@ -910,9 +911,10 @@ def load(name: str, code: str) -> SM100FP8GemmRuntime:
910911
if cubin_name in RUNTIME_CACHE:
911912
return RUNTIME_CACHE[cubin_name]
912913
symbol, sha256 = KERNEL_MAP[cubin_name]
913-
cubin_prefix = f"{_artifact_hash}/deep-gemm/"
914-
get_cubin(cubin_prefix + cubin_name, sha256)
915-
path = FLASHINFER_CACHE_DIR / "cubins" / f"{cubin_prefix + cubin_name}.cubin"
914+
get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
915+
path = (
916+
FLASHINFER_CACHE_DIR / "cubins" / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
917+
)
916918
assert path.exists()
917919
RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)
918920
return RUNTIME_CACHE[cubin_name]
@@ -1436,8 +1438,7 @@ def __init__(self, sha256: str):
14361438
self.indice = None
14371439

14381440
def init_indices(self):
1439-
cubin_prefix = f"{_artifact_hash}/deep-gemm/"
1440-
indice_path = cubin_prefix + "kernel_map"
1441+
indice_path = ArtifactPath.DEEPGEMM + "kernel_map"
14411442
assert get_cubin(
14421443
indice_path, self.sha256, file_extension=".json"
14431444
), "cubin kernel map file not found, nor downloaded with matched sha256"
@@ -1458,6 +1459,4 @@ def __getitem__(self, key):
14581459
return self.indice[key]
14591460

14601461

1461-
KERNEL_MAP = KernelMap(
1462-
"69aa277b7f3663ed929e73f9c57301792b8c594dac15a465b44a5d151b6a1d50"
1463-
)
1462+
KERNEL_MAP = KernelMap(MetaInfoHash.DEEPGEMM)

flashinfer/fused_moe/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import torch
2323

24+
from ..artifacts import ArtifactPath
2425
from ..autotuner import (
2526
AutoTuner,
2627
ConstraintSpec,
@@ -712,6 +713,7 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
712713
"-DENABLE_BF16",
713714
"-DENABLE_FP8",
714715
"-DENABLE_FP4",
716+
f'-DTLLM_GEN_BMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"',
715717
]
716718
+ sm100a_nvcc_flags,
717719
extra_ldflags=["-lcuda"],

flashinfer/gemm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import torch
2626
import torch.nn.functional as F
2727

28+
from .artifacts import ArtifactPath
2829
from .autotuner import (
2930
AutoTuner,
3031
ConstraintSpec,
@@ -309,6 +310,7 @@ def trtllm_gemm_gen_module() -> JitSpec:
309310
extra_cuda_cflags=[
310311
"-DTLLM_GEN_EXPORT_INTERFACE",
311312
"-DTLLM_ENABLE_CUDA",
313+
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',
312314
]
313315
+ sm100a_nvcc_flags,
314316
extra_ldflags=["-lcuda"],

flashinfer/jit/attention/pytorch.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import jinja2
2121
import torch
2222

23+
from ...artifacts import ArtifactPath, MetaInfoHash
2324
from .. import env as jit_env
2425
from ..core import JitSpec, gen_jit_spec, logger, sm90a_nvcc_flags, sm100a_nvcc_flags
2526
from ..utils import (
@@ -1494,6 +1495,10 @@ def trtllm_gen_fmha_module():
14941495
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fmha_kernel_launcher.cu",
14951496
],
14961497
extra_ldflags=["-lcuda"],
1498+
extra_cuda_cflags=[
1499+
f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"',
1500+
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{MetaInfoHash.TRTLLM_GEN_FMHA}\\"',
1501+
],
14971502
)
14981503

14991504

@@ -1593,4 +1598,7 @@ def cudnn_fmha_gen_module():
15931598
"fmha_cudnn_gen",
15941599
[jit_env.FLASHINFER_CSRC_DIR / "cudnn_sdpa_kernel_launcher.cu"],
15951600
extra_ldflags=["-lcuda"],
1601+
extra_cuda_cflags=[
1602+
f'-DCUDNN_SDPA_CUBIN_PATH=\\"{ArtifactPath.CUDNN_SDPA}\\"',
1603+
],
15961604
)

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
#include "KernelMetaInfo.h"
2828
#endif // TLLM_GEN_EXPORT_INTERFACE
2929

30+
#ifdef TLLM_GEN_BMM_CUBIN_PATH
31+
static const std::string tllm_gen_bmm_cubin_path = std::string(TLLM_GEN_BMM_CUBIN_PATH);
32+
#else
33+
static_assert(false, "TLLM_GEN_BMM_CUBIN_PATH macro is not defined when compiling");
34+
#endif
35+
3036
namespace flashinfer::trtllm_cubin_loader {
3137
std::string getCubin(const std::string& kernelName, const std::string& sha256);
3238
}
@@ -645,14 +651,11 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
645651

646652
auto fiModuleLoadData = [&](CUmodule* module) {
647653
const std::string sha256 = config.mHash ? config.mHash : "";
648-
const std::string pipeline_hash = "991e7438224199de85ef08a2730ce18c12b4e0aa";
649-
const std::string cubin_path = pipeline_hash + "/" + std::string("batched_gemm-") +
650-
TLLM_GEN_COMMIT + "-" + TLLM_GEN_BATCHED_GEMM_CONFIG_HASH + "/";
651654
std::string fname_cubin = config.mFunctionName;
652655
if (!fname_cubin.empty()) {
653656
fname_cubin[0] = static_cast<char>(std::toupper(static_cast<unsigned char>(fname_cubin[0])));
654657
}
655-
fname_cubin = cubin_path + fname_cubin;
658+
fname_cubin = tllm_gen_bmm_cubin_path + fname_cubin;
656659
std::string cubin = flashinfer::trtllm_cubin_loader::getCubin(fname_cubin, sha256);
657660
cuModuleLoadData(&cuModule, cubin.c_str());
658661
};

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@
3131
#include "fmhaRunnerParams.h"
3232
#include "kernelParams.h"
3333

34+
#ifdef TLLM_GEN_FMHA_CUBIN_PATH
35+
static const std::string tllm_gen_fmha_cubin_path = std::string(TLLM_GEN_FMHA_CUBIN_PATH);
36+
#else
37+
static_assert(false, "TLLM_GEN_FMHA_CUBIN_PATH macro is not defined when compiling");
38+
#endif
39+
40+
#ifdef TLLM_GEN_FMHA_METAINFO_HASH
41+
static const std::string tllm_gen_fmha_metainfo_hash = std::string(TLLM_GEN_FMHA_METAINFO_HASH);
42+
#else
43+
static_assert(false, "TLLM_GEN_FMHA_METAINFO_HASH macro is not defined when compiling");
44+
#endif
45+
3446
namespace flashinfer::trtllm_cubin_loader {
3547
std::string getCubin(const std::string& kernelName, const std::string& sha256);
3648
std::string getMetaInfo(const std::string& name, const std::string& sha256,
@@ -234,14 +246,6 @@ class TllmGenFmhaKernel {
234246
}
235247
}
236248

237-
static std::string getCubinPath() {
238-
const char* env_hash = std::getenv("FLASHINFER_CUBIN_ARTIFACTORY_HASH");
239-
std::string hash =
240-
env_hash ? std::string(env_hash) : "52e676342c67a3772e06f10b84600044c0c22b76";
241-
std::string cubin_path = hash + "/fmha/trtllm-gen/";
242-
return cubin_path;
243-
}
244-
245249
private:
246250
// Is it MLA generation kernel ?
247251
inline bool isMlaGenKernel(RunnerParams const& params) const {
@@ -539,7 +543,7 @@ class TllmGenFmhaKernel {
539543
};
540544
if (findModuleIter == mModules.end()) {
541545
// Load the module.
542-
std::string cubin_path = TllmGenFmhaKernel::getCubinPath() + kernelMeta.mFuncName;
546+
std::string cubin_path = tllm_gen_fmha_cubin_path + kernelMeta.mFuncName;
543547
std::string cubin = getCubin(cubin_path, kernelMeta.sha256);
544548
if (cubin.empty()) {
545549
throw std::runtime_error("Failed to load cubin for " + kernelName);
@@ -593,9 +597,8 @@ class TllmFmhaKernelFactory {
593597
std::lock_guard<std::mutex> lg(s_mutex);
594598

595599
if (!metainfo_loaded) {
596-
std::string metainfo_raw =
597-
getMetaInfo(TllmGenFmhaKernel::getCubinPath() + "flashInferMetaInfo",
598-
"8c5630020c0452fb1cd1ea7e3b8fdbb7bf94f71bd899ed5b704a490bdb4f7368", ".h");
600+
std::string metainfo_raw = getMetaInfo(tllm_gen_fmha_cubin_path + "flashInferMetaInfo",
601+
tllm_gen_fmha_metainfo_hash, ".h");
599602
metainfo = KernelType::KernelMeta::loadFromMetaInfoRaw(metainfo_raw);
600603
metainfo_loaded = true;
601604
}

0 commit comments

Comments
 (0)