Skip to content

Commit 2489184

Browse files
committed
checksum check
1 parent 107f735 commit 2489184

File tree

1 file changed

+57
-30
lines changed

1 file changed

+57
-30
lines changed

flashinfer/artifacts.py

Lines changed: 57 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
FLASHINFER_CUBINS_REPOSITORY,
2828
get_cubin,
2929
FLASHINFER_CUBIN_DIR,
30+
download_file,
3031
)
3132

3233

@@ -71,43 +72,71 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
7172
class ArtifactPath:
7273
TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen/"
7374
TRTLLM_GEN_BMM: str = (
74-
"e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802/"
75+
"9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/batched_gemm-45beda1-7bdba93/"
7576
)
7677
TRTLLM_GEN_GEMM: str = (
77-
"037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e/"
78+
"9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/gemm-45beda1-f91dc9e/"
7879
)
79-
CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn/"
80-
DEEPGEMM: str = "51d730202c9eef782f06ecc950005331d85c5d4b/deep-gemm/"
80+
CUDNN_SDPA: str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/fmha/cudnn/"
81+
DEEPGEMM: str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/deep-gemm/"
8182

8283

8384
class MetaInfoHash:
8485
TRTLLM_GEN_FMHA: str = (
8586
"2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
8687
)
8788
TRTLLM_GEN_BMM: str = (
88-
"c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34"
89+
"9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4"
8990
)
9091
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
9192
TRTLLM_GEN_GEMM: str = (
92-
"0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba"
93+
"7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67"
9394
)
9495

9596

97+
def get_checksums(kernels):
98+
checksums = {}
99+
for kernel in kernels:
100+
uri = FLASHINFER_CUBINS_REPOSITORY + "/" + (kernel + "checksums.txt")
101+
checksum_path = FLASHINFER_CUBIN_DIR / (kernel + "checksums.txt")
102+
download_file(uri, checksum_path)
103+
with open(checksum_path, "r") as f:
104+
for line in f:
105+
sha256, filename = line.strip().split()
106+
checksums[kernel + filename] = sha256
107+
return checksums
108+
109+
96110
def get_cubin_file_list():
97111
base = FLASHINFER_CUBINS_REPOSITORY.rstrip("/")
98112
cubin_files = [
99-
(ArtifactPath.TRTLLM_GEN_FMHA + "include/flashInferMetaInfo", ".h"),
100-
(ArtifactPath.TRTLLM_GEN_GEMM + "include/flashinferMetaInfo", ".h"),
101-
(ArtifactPath.TRTLLM_GEN_BMM + "include/flashinferMetaInfo", ".h"),
113+
(
114+
ArtifactPath.TRTLLM_GEN_FMHA + "include/flashInferMetaInfo",
115+
".h",
116+
MetaInfoHash.TRTLLM_GEN_FMHA,
117+
),
118+
(
119+
ArtifactPath.TRTLLM_GEN_GEMM + "include/flashinferMetaInfo",
120+
".h",
121+
MetaInfoHash.TRTLLM_GEN_GEMM,
122+
),
123+
(
124+
ArtifactPath.TRTLLM_GEN_BMM + "include/flashinferMetaInfo",
125+
".h",
126+
MetaInfoHash.TRTLLM_GEN_BMM,
127+
),
102128
]
103-
for kernel in [
129+
kernels = [
104130
ArtifactPath.TRTLLM_GEN_FMHA,
105-
ArtifactPath.TRTLLM_GEN_BMM,
106131
ArtifactPath.TRTLLM_GEN_GEMM,
132+
ArtifactPath.TRTLLM_GEN_BMM,
107133
ArtifactPath.DEEPGEMM,
108-
]:
134+
]
135+
checksums = get_checksums(kernels)
136+
137+
for kernel in kernels:
109138
cubin_files += [
110-
(kernel + name, extension)
139+
(kernel + name, extension, checksums[kernel + name + extension])
111140
for name, extension in get_available_cubin_files(
112141
urljoin(base + "/", kernel)
113142
)
@@ -121,27 +150,25 @@ def download_artifacts():
121150
# use a shared session to make use of HTTP keep-alive and reuse of
122151
# HTTPS connections.
123152
session = requests.Session()
153+
cubin_files = get_cubin_file_list()
154+
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
155+
with tqdm_logging_redirect(
156+
total=len(cubin_files), desc="Downloading cubins"
157+
) as pbar:
124158

125-
with temp_env_var("FLASHINFER_CUBIN_CHECKSUM_DISABLED", "1"):
126-
cubin_files = get_cubin_file_list()
127-
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
128-
with tqdm_logging_redirect(
129-
total=len(cubin_files), desc="Downloading cubins"
130-
) as pbar:
131-
132-
def update_pbar_cb(_) -> None:
133-
pbar.update(1)
159+
def update_pbar_cb(_) -> None:
160+
pbar.update(1)
134161

135-
with ThreadPoolExecutor(num_threads) as pool:
136-
futures = []
137-
for name, extension in cubin_files:
138-
fut = pool.submit(get_cubin, name, "", extension, session)
139-
fut.add_done_callback(update_pbar_cb)
140-
futures.append(fut)
162+
with ThreadPoolExecutor(num_threads) as pool:
163+
futures = []
164+
for name, extension, checksum in cubin_files:
165+
fut = pool.submit(get_cubin, name, checksum, extension, session)
166+
fut.add_done_callback(update_pbar_cb)
167+
futures.append(fut)
141168

142-
results = [fut.result() for fut in as_completed(futures)]
169+
results = [fut.result() for fut in as_completed(futures)]
143170

144-
all_success = all(results)
171+
all_success = all(results)
145172
if not all_success:
146173
raise RuntimeError("Failed to download cubins")
147174

0 commit comments

Comments
 (0)