Skip to content

Commit 0f8289d

Browse files
committed
add checksum to files list
1 parent ed99415 commit 0f8289d

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

flashinfer/artifacts.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ def get_checksums(kernels):
107107
return checksums
108108

109109

110-
def get_cubin_file_list():
110+
def get_subdir_file_list():
111111
base = FLASHINFER_CUBINS_REPOSITORY.rstrip("/")
112-
cubin_files = [
112+
subdir_files = [
113113
(
114114
ArtifactPath.TRTLLM_GEN_FMHA + "include/flashInferMetaInfo",
115115
".h",
@@ -135,13 +135,14 @@ def get_cubin_file_list():
135135
checksums = get_checksums(kernels)
136136

137137
for kernel in kernels:
138-
cubin_files += [
138+
subdir_files += [(kernel + "checksums", ".txt", None)]
139+
subdir_files += [
139140
(kernel + name, extension, checksums[kernel + name + extension])
140141
for name, extension in get_available_cubin_files(
141142
urljoin(base + "/", kernel)
142143
)
143144
]
144-
return cubin_files
145+
return subdir_files
145146

146147

147148
def download_artifacts():
@@ -150,7 +151,7 @@ def download_artifacts():
150151
# use a shared session to make use of HTTP keep-alive and reuse of
151152
# HTTPS connections.
152153
session = requests.Session()
153-
cubin_files = get_cubin_file_list()
154+
cubin_files = get_subdir_file_list()
154155
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
155156
with tqdm_logging_redirect(
156157
total=len(cubin_files), desc="Downloading cubins"
@@ -162,6 +163,8 @@ def update_pbar_cb(_) -> None:
162163
with ThreadPoolExecutor(num_threads) as pool:
163164
futures = []
164165
for name, extension, checksum in cubin_files:
166+
if "checksums" in name:
167+
continue
165168
fut = pool.submit(get_cubin, name, checksum, extension, session)
166169
fut.add_done_callback(update_pbar_cb)
167170
futures.append(fut)
@@ -178,7 +181,7 @@ def get_artifacts_status():
178181
Check which cubins are already downloaded and return (num_downloaded, total).
179182
Does not download any cubins.
180183
"""
181-
cubin_files = get_cubin_file_list()
184+
cubin_files = get_subdir_file_list()
182185
status = []
183186
for name, extension in cubin_files:
184187
# get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path

flashinfer/jit/cubin_loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ def load_cubin(cubin_path, sha256) -> bytes:
125125
actual_sha = m.hexdigest()
126126
if sha256 == actual_sha:
127127
return cubin
128+
if "checksums" in cubin_path: # checksum file isn't checked
129+
return cubin
128130
logger.warning(
129131
f"sha256 mismatch (expected {sha256} actual {actual_sha}) for {cubin_path}"
130132
)

0 commit comments

Comments
 (0)