32
32
33
33
from .jit .cubin_loader import (
34
34
FLASHINFER_CUBINS_REPOSITORY ,
35
- download_file ,
36
35
safe_urljoin ,
37
36
FLASHINFER_CUBIN_DIR ,
38
37
download_file ,
38
+ verify_cubin ,
39
39
)
40
40
41
41
@@ -91,22 +91,11 @@ class ArtifactPath:
91
91
DEEPGEMM : str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
92
92
93
93
94
- # TODO (jimmyzho): Should be deprecated except DEEPGEMM
95
94
@dataclass (frozen = True )
96
95
class MetaInfoHash :
97
- TRTLLM_GEN_FMHA : str = (
98
- "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
99
- )
100
- TRTLLM_GEN_BMM : str = (
101
- "9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4"
102
- )
103
96
DEEPGEMM : str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
104
- TRTLLM_GEN_GEMM : str = (
105
- "7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67"
106
- )
107
97
108
98
109
- # @dataclass(frozen=True)
110
99
class CheckSumHash :
111
100
TRTLLM_GEN_FMHA : str = (
112
101
"b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
@@ -197,12 +186,13 @@ def download_artifacts() -> None:
197
186
with tqdm_logging_redirect (
198
187
total = len (cubin_files ), desc = "Downloading cubins"
199
188
) as pbar :
189
+
200
190
def update_pbar_cb (_ ) -> None :
201
191
pbar .update (1 )
202
192
203
193
with ThreadPoolExecutor (num_threads ) as pool :
204
194
futures = []
205
- for name in cubin_files :
195
+ for name , _ in cubin_files :
206
196
source = safe_urljoin (FLASHINFER_CUBINS_REPOSITORY , name )
207
197
local_path = FLASHINFER_CUBIN_DIR / name
208
198
# Ensure parent directory exists
@@ -219,6 +209,12 @@ def update_pbar_cb(_) -> None:
219
209
if not all_success :
220
210
raise RuntimeError ("Failed to download cubins" )
221
211
212
+ # Check checksums of all downloaded cubins
213
+ for name , checksum in cubin_files :
214
+ local_path = FLASHINFER_CUBIN_DIR / name
215
+ if not verify_cubin (str (local_path ), checksum ):
216
+ raise RuntimeError ("Failed to download cubins: checksum mismatch" )
217
+
222
218
223
219
def get_artifacts_status () -> tuple [tuple [str , bool ], ...]:
224
220
"""
0 commit comments