Skip to content

Commit 7c471b9

Browse files
committed
checksum check
1 parent fd1d506 commit 7c471b9

File tree

1 file changed

+77
-39
lines changed

1 file changed

+77
-39
lines changed

flashinfer/artifacts.py

Lines changed: 77 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@
2222
from typing import Generator
2323
import requests # type: ignore[import-untyped]
2424
import shutil
25+
import hashlib
2526

2627
from .jit.core import logger
2728
from .jit.cubin_loader import (
2829
FLASHINFER_CUBINS_REPOSITORY,
2930
get_cubin,
3031
safe_urljoin,
3132
FLASHINFER_CUBIN_DIR,
33+
download_file,
3234
)
3335

3436

@@ -72,50 +74,88 @@ def get_available_cubin_files(
7274
return tuple()
7375

7476

75-
@dataclass(frozen=True)
7677
class ArtifactPath:
77-
TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen"
78+
TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/"
7879
TRTLLM_GEN_BMM: str = (
79-
"e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802"
80+
"a72d85b019dc125b9f711300cb989430f762f5a6/batched_gemm-145d1b1-9e1d49a/"
8081
)
8182
TRTLLM_GEN_GEMM: str = (
82-
"037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e"
83+
"a72d85b019dc125b9f711300cb989430f762f5a6/gemm-145d1b1-f91dc9e/"
8384
)
84-
CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn"
85-
DEEPGEMM: str = "51d730202c9eef782f06ecc950005331d85c5d4b/deep-gemm"
86-
85+
CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
86+
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
8787

88+
# TODO: Should be deprecated
8889
@dataclass(frozen=True)
8990
class MetaInfoHash:
9091
TRTLLM_GEN_FMHA: str = (
9192
"2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
9293
)
9394
TRTLLM_GEN_BMM: str = (
94-
"c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34"
95+
"9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4"
9596
)
9697
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
9798
TRTLLM_GEN_GEMM: str = (
98-
"0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba"
99+
"7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67"
99100
)
100101

101102

102-
def get_cubin_file_list() -> Generator[str, None, None]:
103-
base = FLASHINFER_CUBINS_REPOSITORY
103+
# Kept as a plain class on purpose: ``map_checksums`` is an annotated attribute
# with a dict default, and ``@dataclass`` rejects mutable defaults on fields.
class CheckSumHash:
    """Pinned digests of each artifact directory's ``checksums.txt`` file.

    The per-artifact attributes hold 64-hex-digit digests (presumably SHA-256
    -- confirm against the verification done in ``get_cubin``).
    ``map_checksums`` exposes the same digests keyed by the repository-relative
    path of the corresponding ``checksums.txt``.
    """

    TRTLLM_GEN_FMHA: str = (
        "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
    )
    TRTLLM_GEN_BMM: str = (
        "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
    )
    DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
    TRTLLM_GEN_GEMM: str = (
        "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
    )

    # ``<artifact dir>/checksums.txt`` -> pinned digest of that file.
    map_checksums: dict[str, str] = {
        safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA,
        safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM,
        safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM,
        safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "checksums.txt"): TRTLLM_GEN_GEMM,
    }
109116

110-
# All the actual kernel cubin's.
111-
for kernel in [
117+
118+
def get_checksums(subdirs: list[str]) -> dict[str, str]:
    """Download and merge the ``checksums.txt`` of every artifact directory.

    For each entry of *subdirs*, ``<subdir>/checksums.txt`` is fetched from
    ``FLASHINFER_CUBINS_REPOSITORY`` to the same relative location under
    ``FLASHINFER_CUBIN_DIR`` and parsed as whitespace-separated
    ``<digest> <filename>`` lines.

    Args:
        subdirs: Artifact directories, relative to the repository root.

    Returns:
        Mapping of file name to digest. Header files (names ending in ``.h``)
        are keyed by ``<subdir>/<filename>`` so that the meta info headers of
        different directories stay distinct; every other file is keyed by its
        bare name as listed in ``checksums.txt``.

    Raises:
        ValueError: If a non-blank line does not consist of exactly two
            whitespace-separated fields.
    """
    checksums: dict[str, str] = {}
    for subdir in subdirs:
        relative_path = safe_urljoin(subdir, "checksums.txt")
        uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, relative_path)
        checksum_path = FLASHINFER_CUBIN_DIR / relative_path
        # NOTE(review): the result of download_file is not inspected here;
        # confirm how it reports a failed download.
        download_file(uri, checksum_path)
        with open(checksum_path, "r", encoding="utf-8") as f:
            for line_number, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f"{checksum_path}:{line_number}: expected "
                        f"'<digest> <filename>', got {line.rstrip()!r}"
                    )
                sha256, filename = fields

                # Distinguish between all meta info header files. Match on the
                # suffix: a substring test would also catch any other name
                # that merely contains ".h".
                if filename.endswith(".h"):
                    filename = safe_urljoin(subdir, filename)
                checksums[filename] = sha256
    return checksums
133+
134+
135+
def get_subdir_file_list() -> Generator[tuple[str, str], None, None]:
    """Yield ``(relative_path, checksum)`` for every artifact file.

    Calls ``get_checksums`` first, which downloads the ``checksums.txt`` of
    every artifact directory, so iterating this generator performs network
    and file-system I/O before the first item is produced.

    Yield order: the three meta info headers, then for each artifact directory
    its ``checksums.txt`` (paired with the pinned digest from
    ``CheckSumHash.map_checksums``) followed by every file reported by
    ``get_available_cubin_files`` for that directory.

    Raises:
        KeyError: If a header, a ``checksums.txt`` or a listed file has no
            known checksum.
    """
    base = FLASHINFER_CUBINS_REPOSITORY

    cubin_dirs = [
        ArtifactPath.TRTLLM_GEN_FMHA,
        ArtifactPath.TRTLLM_GEN_BMM,
        ArtifactPath.TRTLLM_GEN_GEMM,
        ArtifactPath.DEEPGEMM,
    ]

    # Get checksums of all files
    checksums = get_checksums(cubin_dirs)

    # The meta info header files first. The FMHA header is spelled with a
    # capital "I" ("flashInfer..."), unlike the GEMM and BMM headers.
    meta_info_headers = (
        (ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"),
        (ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"),
        (ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"),
    )
    for artifact_dir, header in meta_info_headers:
        header_path = safe_urljoin(artifact_dir, header)
        yield (header_path, checksums[header_path])

    # All the actual kernel cubins.
    for cubin_dir in cubin_dirs:
        checksum_path = safe_urljoin(cubin_dir, "checksums.txt")
        yield (checksum_path, CheckSumHash.map_checksums[checksum_path])
        for name in get_available_cubin_files(safe_urljoin(base, cubin_dir)):
            full_path = safe_urljoin(cubin_dir, name)
            # Prefer the directory-qualified key and fall back to the bare
            # file name, so the lookup works whichever way the checksum table
            # keyed this file.
            checksum = checksums.get(full_path, checksums.get(name))
            if checksum is None:
                raise KeyError(
                    f"no checksum for {full_path} in "
                    f"{safe_urljoin(cubin_dir, 'checksums.txt')}"
                )
            yield (full_path, checksum)
119159

120160

121161
def download_artifacts() -> None:
@@ -124,27 +164,25 @@ def download_artifacts() -> None:
124164
# use a shared session to make use of HTTP keep-alive and reuse of
125165
# HTTPS connections.
126166
session = requests.Session()
167+
cubin_files = list(get_subdir_file_list())
168+
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
169+
with tqdm_logging_redirect(
170+
total=len(cubin_files), desc="Downloading cubins"
171+
) as pbar:
127172

128-
with temp_env_var("FLASHINFER_CUBIN_CHECKSUM_DISABLED", "1"):
129-
cubin_files = list(get_cubin_file_list())
130-
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
131-
with tqdm_logging_redirect(
132-
total=len(cubin_files), desc="Downloading cubins"
133-
) as pbar:
134-
135-
def update_pbar_cb(_) -> None:
136-
pbar.update(1)
173+
def update_pbar_cb(_) -> None:
174+
pbar.update(1)
137175

138-
with ThreadPoolExecutor(num_threads) as pool:
139-
futures = []
140-
for name in cubin_files:
141-
fut = pool.submit(get_cubin, name, "", session)
142-
fut.add_done_callback(update_pbar_cb)
143-
futures.append(fut)
176+
with ThreadPoolExecutor(num_threads) as pool:
177+
futures = []
178+
for name, checksum in cubin_files:
179+
fut = pool.submit(get_cubin, name, checksum, session)
180+
fut.add_done_callback(update_pbar_cb)
181+
futures.append(fut)
144182

145-
results = [fut.result() for fut in as_completed(futures)]
183+
results = [fut.result() for fut in as_completed(futures)]
146184

147-
all_success = all(results)
185+
all_success = all(results)
148186
if not all_success:
149187
raise RuntimeError("Failed to download cubins")
150188

@@ -154,7 +192,7 @@ def get_artifacts_status() -> tuple[tuple[str, bool], ...]:
154192
Check which cubins are already downloaded and return (num_downloaded, total).
155193
Does not download any cubins.
156194
"""
157-
cubin_files = get_cubin_file_list()
195+
cubin_files = get_subdir_file_list()
158196

159197
def _check_file_status(file_name: str) -> tuple[str, bool]:
160198
# get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path
@@ -163,7 +201,7 @@ def _check_file_status(file_name: str) -> tuple[str, bool]:
163201
exists = os.path.isfile(local_path)
164202
return (file_name, exists)
165203

166-
return tuple(_check_file_status(file_name) for file_name in cubin_files)
204+
return tuple(_check_file_status(file_name) for file_name, _ in cubin_files)
167205

168206

169207
def clear_cubin():

0 commit comments

Comments
 (0)