Skip to content

Commit a662e76

Browse files
committed
checksum check
1 parent 50319b2 commit a662e76

File tree

1 file changed

+59
-32
lines changed

1 file changed

+59
-32
lines changed

flashinfer/artifacts.py

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
FLASHINFER_CUBINS_REPOSITORY,
2828
get_cubin,
2929
FLASHINFER_CUBIN_DIR,
30+
download_file,
3031
)
3132

3233

@@ -69,44 +70,72 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
6970

7071

7172
class ArtifactPath:
72-
TRTLLM_GEN_FMHA: str = "538f8e38ace07f701f61e26b138b2b8c70ce9e8e/fmha/trtllm-gen/"
73+
TRTLLM_GEN_FMHA: str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/fmha/trtllm-gen/"
7374
TRTLLM_GEN_BMM: str = (
74-
"e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802/"
75+
"9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/batched_gemm-45beda1-7bdba93/"
7576
)
7677
TRTLLM_GEN_GEMM: str = (
77-
"037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e/"
78+
"9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/gemm-45beda1-f91dc9e/"
7879
)
79-
CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn/"
80-
DEEPGEMM: str = "51d730202c9eef782f06ecc950005331d85c5d4b/deep-gemm/"
80+
CUDNN_SDPA: str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/fmha/cudnn/"
81+
DEEPGEMM: str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/deep-gemm/"
8182

8283

8384
class MetaInfoHash:
8485
TRTLLM_GEN_FMHA: str = (
85-
"71f06a8fc03d28cc94ee6fc180fb7e37256a9e1c30ab2a6c0bf20a2d97af3eff"
86+
"875f50e8f466120b1a59b94397835b86fad785942b4036823230465bc618b919"
8687
)
8788
TRTLLM_GEN_BMM: str = (
88-
"c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34"
89+
"9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4"
8990
)
9091
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
9192
TRTLLM_GEN_GEMM: str = (
92-
"0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba"
93+
"7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67"
9394
)
9495

9596

97+
def get_checksums(kernels):
98+
checksums = {}
99+
for kernel in kernels:
100+
uri = FLASHINFER_CUBINS_REPOSITORY + "/" + (kernel + "checksums.txt")
101+
checksum_path = FLASHINFER_CUBIN_DIR / (kernel + "checksums.txt")
102+
download_file(uri, checksum_path)
103+
with open(checksum_path, "r") as f:
104+
for line in f:
105+
sha256, filename = line.strip().split()
106+
checksums[kernel + filename] = sha256
107+
return checksums
108+
109+
96110
def get_cubin_file_list():
97111
cubin_files = [
98-
(ArtifactPath.TRTLLM_GEN_FMHA + "include/flashInferMetaInfo", ".h"),
99-
(ArtifactPath.TRTLLM_GEN_GEMM + "include/flashinferMetaInfo", ".h"),
100-
(ArtifactPath.TRTLLM_GEN_BMM + "include/flashinferMetaInfo", ".h"),
112+
(
113+
ArtifactPath.TRTLLM_GEN_FMHA + "include/flashInferMetaInfo",
114+
".h",
115+
MetaInfoHash.TRTLLM_GEN_FMHA,
116+
),
117+
(
118+
ArtifactPath.TRTLLM_GEN_GEMM + "include/flashinferMetaInfo",
119+
".h",
120+
MetaInfoHash.TRTLLM_GEN_GEMM,
121+
),
122+
(
123+
ArtifactPath.TRTLLM_GEN_BMM + "include/flashinferMetaInfo",
124+
".h",
125+
MetaInfoHash.TRTLLM_GEN_BMM,
126+
),
101127
]
102-
for kernel in [
128+
kernels = [
103129
ArtifactPath.TRTLLM_GEN_FMHA,
104-
ArtifactPath.TRTLLM_GEN_BMM,
105130
ArtifactPath.TRTLLM_GEN_GEMM,
131+
ArtifactPath.TRTLLM_GEN_BMM,
106132
ArtifactPath.DEEPGEMM,
107-
]:
133+
]
134+
checksums = get_checksums(kernels)
135+
136+
for kernel in kernels:
108137
cubin_files += [
109-
(kernel + name, extension)
138+
(kernel + name, extension, checksums[kernel + name + extension])
110139
for name, extension in get_available_cubin_files(
111140
FLASHINFER_CUBINS_REPOSITORY + "/" + kernel
112141
)
@@ -120,27 +149,25 @@ def download_artifacts():
120149
# use a shared session to make use of HTTP keep-alive and reuse of
121150
# HTTPS connections.
122151
session = requests.Session()
152+
cubin_files = get_cubin_file_list()
153+
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
154+
with tqdm_logging_redirect(
155+
total=len(cubin_files), desc="Downloading cubins"
156+
) as pbar:
123157

124-
with temp_env_var("FLASHINFER_CUBIN_CHECKSUM_DISABLED", "1"):
125-
cubin_files = get_cubin_file_list()
126-
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
127-
with tqdm_logging_redirect(
128-
total=len(cubin_files), desc="Downloading cubins"
129-
) as pbar:
130-
131-
def update_pbar_cb(_) -> None:
132-
pbar.update(1)
158+
def update_pbar_cb(_) -> None:
159+
pbar.update(1)
133160

134-
with ThreadPoolExecutor(num_threads) as pool:
135-
futures = []
136-
for name, extension in cubin_files:
137-
fut = pool.submit(get_cubin, name, "", extension, session)
138-
fut.add_done_callback(update_pbar_cb)
139-
futures.append(fut)
161+
with ThreadPoolExecutor(num_threads) as pool:
162+
futures = []
163+
for name, extension, checksum in cubin_files:
164+
fut = pool.submit(get_cubin, name, checksum, extension, session)
165+
fut.add_done_callback(update_pbar_cb)
166+
futures.append(fut)
140167

141-
results = [fut.result() for fut in as_completed(futures)]
168+
results = [fut.result() for fut in as_completed(futures)]
142169

143-
all_success = all(results)
170+
all_success = all(results)
144171
if not all_success:
145172
raise RuntimeError("Failed to download cubins")
146173

0 commit comments

Comments
 (0)