22
22
from typing import Generator
23
23
import requests # type: ignore[import-untyped]
24
24
import shutil
25
+ import hashlib
25
26
26
27
from .jit .core import logger
27
28
from .jit .cubin_loader import (
28
29
FLASHINFER_CUBINS_REPOSITORY ,
29
30
get_cubin ,
30
31
safe_urljoin ,
31
32
FLASHINFER_CUBIN_DIR ,
33
+ download_file ,
32
34
)
33
35
34
36
@@ -72,50 +74,88 @@ def get_available_cubin_files(
72
74
return tuple ()
73
75
74
76
75
class ArtifactPath:
    """Repository-relative directories of each downloadable artifact family.

    Each value has the form ``<40-hex prefix>/<directory>/`` and is combined
    with other path components through ``safe_urljoin`` (for example with
    ``FLASHINFER_CUBINS_REPOSITORY`` to form a download URL, or with
    ``"checksums.txt"`` to locate the per-directory checksum manifest).

    The values must not contain whitespace: they are used verbatim as URL
    and filesystem path segments.
    """

    # NOTE(review): the hex prefix is presumably the revision the artifacts
    # were published from -- confirm with the artifact publishing process.
    TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/"
    TRTLLM_GEN_BMM: str = (
        "a72d85b019dc125b9f711300cb989430f762f5a6/batched_gemm-145d1b1-9e1d49a/"
    )
    TRTLLM_GEN_GEMM: str = (
        "a72d85b019dc125b9f711300cb989430f762f5a6/gemm-145d1b1-f91dc9e/"
    )
    CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
    DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
87
87
88
# TODO: Should be deprecated
@dataclass(frozen=True)
class MetaInfoHash:
    """Pinned 64-character hex digests, one per artifact family.

    NOTE(review): presumably the SHA-256 of each family's meta info header
    (``include/flash[iI]nferMetaInfo.h``) -- confirm against the callers
    before relying on that.

    Each value must be exactly 64 lowercase hex characters with no
    surrounding whitespace, otherwise an equality check against a computed
    digest can never succeed.
    """

    TRTLLM_GEN_FMHA: str = (
        "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
    )
    TRTLLM_GEN_BMM: str = (
        "9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4"
    )
    DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
    TRTLLM_GEN_GEMM: str = (
        "7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67"
    )
100
101
101
102
102
- def get_cubin_file_list () -> Generator [str , None , None ]:
103
- base = FLASHINFER_CUBINS_REPOSITORY
103
class CheckSumHash:
    """Pinned digests of each artifact directory's ``checksums.txt`` manifest.

    The per-family constants hold the expected digest of the manifest itself;
    ``map_checksums`` exposes the same values keyed by the manifest's
    repository-relative path (``<ArtifactPath.X>checksums.txt`` as built by
    ``safe_urljoin``), which is the key ``get_subdir_file_list`` looks up.

    This is deliberately a plain class rather than a dataclass: the annotated
    ``map_checksums`` dict would become a field with a mutable default, which
    ``dataclasses`` rejects at class-creation time.
    """

    TRTLLM_GEN_FMHA: str = (
        "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
    )
    TRTLLM_GEN_BMM: str = (
        "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
    )
    DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
    TRTLLM_GEN_GEMM: str = (
        "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
    )

    # Manifest path -> expected digest. The bare constant names resolve
    # because a dict literal in the class body is evaluated in class scope.
    map_checksums: dict[str, str] = {
        safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA,
        safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM,
        safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM,
        safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "checksums.txt"): TRTLLM_GEN_GEMM,
    }
109
116
110
- # All the actual kernel cubin's.
111
- for kernel in [
117
+
118
def get_checksums(subdirs: list[str]) -> dict[str, str]:
    """Download and parse the ``checksums.txt`` manifest of each subdirectory.

    For every entry of ``subdirs`` the manifest is fetched with
    ``download_file`` from ``FLASHINFER_CUBINS_REPOSITORY`` into the same
    relative location under ``FLASHINFER_CUBIN_DIR`` and then read line by
    line. Each non-empty line must consist of exactly two
    whitespace-separated tokens: ``<sha256> <filename>``.

    Args:
        subdirs: Repository-relative artifact directories (the
            ``ArtifactPath`` values) whose manifests should be loaded.

    Returns:
        Mapping from file name to its digest. Header files (names ending in
        ``.h``) are keyed by ``safe_urljoin(subdir, filename)`` because the
        same header name occurs in several directories; every other file is
        keyed by its bare name as listed in the manifest.

    Raises:
        ValueError: If a non-empty manifest line does not contain exactly
            two whitespace-separated tokens.
    """
    checksums: dict[str, str] = {}
    for subdir in subdirs:
        manifest_rel_path = safe_urljoin(subdir, "checksums.txt")
        uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, manifest_rel_path)
        checksum_path = FLASHINFER_CUBIN_DIR / manifest_rel_path
        download_file(uri, checksum_path)
        with open(checksum_path, "r") as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                sha256, filename = entry.split()

                # Meta info headers share a name across directories, so
                # qualify them with their directory. Match on the suffix:
                # a substring test would also catch names that merely
                # contain ".h" somewhere in the middle.
                if filename.endswith(".h"):
                    filename = safe_urljoin(subdir, filename)
                # NOTE(review): non-header names are not qualified, so a
                # later directory overwrites an identically named file from
                # an earlier one -- confirm names are unique across dirs.
                checksums[filename] = sha256
    return checksums
133
+
134
+
135
def get_subdir_file_list() -> Generator[tuple[str, str], None, None]:
    """Yield ``(relative_path, checksum)`` for every artifact to download.

    The order is: the three meta info headers first, then for each cubin
    directory its ``checksums.txt`` manifest followed by every file reported
    by ``get_available_cubin_files`` for that directory.

    Note that this is not a pure listing: ``get_checksums`` is called when
    iteration starts, and it downloads each directory's ``checksums.txt``.

    Yields:
        Tuples of the repository-relative file path and its expected digest.

    Raises:
        KeyError: If a header or an available file has no entry in the
            parsed manifests, or a manifest path is missing from
            ``CheckSumHash.map_checksums``.
    """
    base = FLASHINFER_CUBINS_REPOSITORY

    cubin_dirs = [
        ArtifactPath.TRTLLM_GEN_FMHA,
        ArtifactPath.TRTLLM_GEN_BMM,
        ArtifactPath.TRTLLM_GEN_GEMM,
        ArtifactPath.DEEPGEMM,
    ]

    # Get checksums of all files
    checksums = get_checksums(cubin_dirs)

    # The meta info header files first. The FMHA header is spelled with a
    # capital "I" ("flashInfer"), unlike the GEMM and BMM headers.
    meta_info_headers = (
        (ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"),
        (ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"),
        (ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"),
    )
    for subdir, header in meta_info_headers:
        header_path = safe_urljoin(subdir, header)
        yield (header_path, checksums[header_path])

    # All the actual kernel cubin's.
    for cubin_dir in cubin_dirs:
        # The manifest itself is verified against the digest pinned in code.
        checksum_path = safe_urljoin(cubin_dir, "checksums.txt")
        yield (checksum_path, CheckSumHash.map_checksums[checksum_path])
        for name in get_available_cubin_files(safe_urljoin(base, cubin_dir)):
            yield (safe_urljoin(cubin_dir, name), checksums[name])
119
159
120
160
121
161
def download_artifacts () -> None :
@@ -124,27 +164,25 @@ def download_artifacts() -> None:
124
164
# use a shared session to make use of HTTP keep-alive and reuse of
125
165
# HTTPS connections.
126
166
session = requests .Session ()
167
+ cubin_files = list (get_subdir_file_list ())
168
+ num_threads = int (os .environ .get ("FLASHINFER_CUBIN_DOWNLOAD_THREADS" , "4" ))
169
+ with tqdm_logging_redirect (
170
+ total = len (cubin_files ), desc = "Downloading cubins"
171
+ ) as pbar :
127
172
128
- with temp_env_var ("FLASHINFER_CUBIN_CHECKSUM_DISABLED" , "1" ):
129
- cubin_files = list (get_cubin_file_list ())
130
- num_threads = int (os .environ .get ("FLASHINFER_CUBIN_DOWNLOAD_THREADS" , "4" ))
131
- with tqdm_logging_redirect (
132
- total = len (cubin_files ), desc = "Downloading cubins"
133
- ) as pbar :
134
-
135
- def update_pbar_cb (_ ) -> None :
136
- pbar .update (1 )
173
+ def update_pbar_cb (_ ) -> None :
174
+ pbar .update (1 )
137
175
138
- with ThreadPoolExecutor (num_threads ) as pool :
139
- futures = []
140
- for name in cubin_files :
141
- fut = pool .submit (get_cubin , name , "" , session )
142
- fut .add_done_callback (update_pbar_cb )
143
- futures .append (fut )
176
+ with ThreadPoolExecutor (num_threads ) as pool :
177
+ futures = []
178
+ for name , checksum in cubin_files :
179
+ fut = pool .submit (get_cubin , name , checksum , session )
180
+ fut .add_done_callback (update_pbar_cb )
181
+ futures .append (fut )
144
182
145
- results = [fut .result () for fut in as_completed (futures )]
183
+ results = [fut .result () for fut in as_completed (futures )]
146
184
147
- all_success = all (results )
185
+ all_success = all (results )
148
186
if not all_success :
149
187
raise RuntimeError ("Failed to download cubins" )
150
188
@@ -154,7 +192,7 @@ def get_artifacts_status() -> tuple[tuple[str, bool], ...]:
154
192
Check which cubins are already downloaded and return (num_downloaded, total).
155
193
Does not download any cubins.
156
194
"""
157
- cubin_files = get_cubin_file_list ()
195
+ cubin_files = get_subdir_file_list ()
158
196
159
197
def _check_file_status (file_name : str ) -> tuple [str , bool ]:
160
198
# get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path
@@ -163,7 +201,7 @@ def _check_file_status(file_name: str) -> tuple[str, bool]:
163
201
exists = os .path .isfile (local_path )
164
202
return (file_name , exists )
165
203
166
- return tuple (_check_file_status (file_name ) for file_name in cubin_files )
204
+ return tuple (_check_file_status (file_name ) for file_name , _ in cubin_files )
167
205
168
206
169
207
def clear_cubin ():
0 commit comments