Skip to content

Commit 38d4940

Browse files
committed
checksum check
1 parent 352b3ad commit 38d4940

File tree

1 file changed

+63
-25
lines changed

1 file changed

+63
-25
lines changed

flashinfer/artifacts.py

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing import Generator
2424
import requests # type: ignore[import-untyped]
2525
import shutil
26+
import hashlib
2627

2728
# Create logger for artifacts module to avoid circular import with jit.core
2829
logger = logging.getLogger("flashinfer.artifacts")
@@ -35,6 +36,7 @@
3536
download_file,
3637
safe_urljoin,
3738
FLASHINFER_CUBIN_DIR,
39+
download_file,
3840
)
3941

4042

@@ -78,50 +80,88 @@ def get_available_cubin_files(
7880
return tuple()
7981

8082

@dataclass(frozen=True)
class ArtifactPath:
    """Repository-relative directories of the prebuilt kernel artifacts.

    Every value is a 40-hex-character revision prefix followed by the
    artifact directory. The values are joined with ``safe_urljoin`` onto the
    cubin repository base URL and onto the local cubin directory.

    The class is a frozen dataclass, like ``MetaInfoHash``, so that instances
    cannot have their paths reassigned.
    """

    # NOTE(review): every path ends with "/", presumably so that safe_urljoin
    # keeps the last path segment when a file name is appended -- confirm.
    TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/"
    TRTLLM_GEN_BMM: str = (
        "a72d85b019dc125b9f711300cb989430f762f5a6/batched_gemm-145d1b1-9e1d49a/"
    )
    TRTLLM_GEN_GEMM: str = (
        "a72d85b019dc125b9f711300cb989430f762f5a6/gemm-145d1b1-f91dc9e/"
    )
    CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
    DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
9393

# TODO: Should be deprecated
@dataclass(frozen=True)
class MetaInfoHash:
    """Pinned 64-hex-character digests, one per artifact family.

    NOTE(review): presumably the expected hash of each family's meta info
    header file -- confirm against the code that consumes these values.
    """

    TRTLLM_GEN_FMHA: str = (
        "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
    )
    TRTLLM_GEN_BMM: str = (
        "9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4"
    )
    DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
    TRTLLM_GEN_GEMM: str = (
        "7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67"
    )
106107

107108

# @dataclass cannot be applied as-is: ``map_checksums`` has a dict default,
# and dataclasses reject mutable defaults such as dict.
class CheckSumHash:
    """Pinned digests of each artifact directory's ``checksums.txt`` file.

    The per-family constants hold 64-hex-character digests. ``map_checksums``
    maps the repository-relative path of each ``checksums.txt`` (artifact
    directory joined with the file name) to the matching digest.
    """

    TRTLLM_GEN_FMHA: str = (
        "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
    )
    TRTLLM_GEN_BMM: str = (
        "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
    )
    DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
    TRTLLM_GEN_GEMM: str = (
        "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
    )

    # Built inside the class body, so the constants above are referenced by
    # their bare names rather than through ``CheckSumHash.``.
    map_checksums: dict[str, str] = {
        safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA,
        safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM,
        safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM,
        safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "checksums.txt"): TRTLLM_GEN_GEMM,
    }
115122

def get_checksums(subdirs: list[str]) -> dict[str, str]:
    """Download each subdirectory's ``checksums.txt`` and collect its entries.

    For every entry of ``subdirs`` the file ``<subdir>/checksums.txt`` is
    fetched from ``FLASHINFER_CUBINS_REPOSITORY`` into the same relative
    location under ``FLASHINFER_CUBIN_DIR`` and then parsed. Each non-blank
    line must consist of a digest followed by a file name.

    Args:
        subdirs: Repository-relative artifact directories to read.

    Returns:
        A mapping from file name to digest. Header files (names ending in
        ``.h``) are keyed by ``<subdir>/<name>``; every other file is keyed
        by the name exactly as it appears in ``checksums.txt``.

    Raises:
        ValueError: If a non-blank line does not contain exactly two
            whitespace-separated fields.
    """
    # NOTE(review): non-header entries are keyed by their bare name, so equal
    # names in two subdirectories overwrite each other -- confirm that names
    # are unique across artifact directories.
    checksums: dict[str, str] = {}
    for subdir in subdirs:
        relative_path = safe_urljoin(subdir, "checksums.txt")
        uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, relative_path)
        checksum_path = FLASHINFER_CUBIN_DIR / relative_path
        download_file(uri, checksum_path)
        with open(checksum_path, "r") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f"Malformed line in {checksum_path}: {line!r}"
                    )
                sha256, filename = fields

                # Distinguish between all meta info header files: they share a
                # name across subdirectories, so key them by their full path.
                # Match on the suffix so that a name which merely contains
                # ".h" somewhere is not mistaken for a header.
                if filename.endswith(".h"):
                    filename = safe_urljoin(subdir, filename)
                checksums[filename] = sha256
    return checksums
139+
140+
def get_subdir_file_list() -> Generator[tuple[str, str], None, None]:
    """Yield ``(relative path, checksum)`` pairs for the artifact files.

    The three meta info headers come first. Then, for each artifact
    directory, its ``checksums.txt`` is yielded with the digest pinned in
    ``CheckSumHash.map_checksums``, followed by every file reported by
    ``get_available_cubin_files`` for that directory with the digest read
    from the downloaded checksum files.

    Raises:
        KeyError: If a yielded file has no entry in the collected checksums.
    """
    repository = FLASHINFER_CUBINS_REPOSITORY

    artifact_dirs = [
        ArtifactPath.TRTLLM_GEN_FMHA,
        ArtifactPath.TRTLLM_GEN_BMM,
        ArtifactPath.TRTLLM_GEN_GEMM,
        ArtifactPath.DEEPGEMM,
    ]

    # Get checksums of all files
    known_checksums = get_checksums(artifact_dirs)

    # The meta info header files first. The FMHA header is spelled with a
    # capital "I" ("flashInfer..."), unlike the other two.
    header_locations = (
        (ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"),
        (ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"),
        (ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"),
    )
    for artifact_dir, header_name in header_locations:
        header_path = safe_urljoin(artifact_dir, header_name)
        yield (header_path, known_checksums[header_path])

    # All the actual kernel cubin's.
    for artifact_dir in artifact_dirs:
        manifest_path = safe_urljoin(artifact_dir, "checksums.txt")
        yield (manifest_path, CheckSumHash.map_checksums[manifest_path])

        listing_url = safe_urljoin(repository, artifact_dir)
        for name in get_available_cubin_files(listing_url):
            yield (safe_urljoin(artifact_dir, name), known_checksums[name])
125165

126166

127167
def download_artifacts() -> None:
@@ -130,13 +170,11 @@ def download_artifacts() -> None:
130170
# use a shared session to make use of HTTP keep-alive and reuse of
131171
# HTTPS connections.
132172
session = requests.Session()
133-
134-
cubin_files = list(get_cubin_file_list())
173+
cubin_files = list(get_subdir_file_list())
135174
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
136175
with tqdm_logging_redirect(
137176
total=len(cubin_files), desc="Downloading cubins"
138177
) as pbar:
139-
140178
def update_pbar_cb(_) -> None:
141179
pbar.update(1)
142180

@@ -165,7 +203,7 @@ def get_artifacts_status() -> tuple[tuple[str, bool], ...]:
165203
Check which cubins are already downloaded and return (num_downloaded, total).
166204
Does not download any cubins.
167205
"""
168-
cubin_files = get_cubin_file_list()
206+
cubin_files = get_subdir_file_list()
169207

170208
def _check_file_status(file_name: str) -> tuple[str, bool]:
171209
# get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path
@@ -174,7 +212,7 @@ def _check_file_status(file_name: str) -> tuple[str, bool]:
174212
exists = os.path.isfile(local_path)
175213
return (file_name, exists)
176214

177-
return tuple(_check_file_status(file_name) for file_name in cubin_files)
215+
return tuple(_check_file_status(file_name) for file_name, _ in cubin_files)
178216

179217

180218
def clear_cubin():

0 commit comments

Comments
 (0)