23
23
from typing import Generator
24
24
import requests # type: ignore[import-untyped]
25
25
import shutil
26
+ import hashlib
26
27
27
28
# Create logger for artifacts module to avoid circular import with jit.core
28
29
logger = logging .getLogger ("flashinfer.artifacts" )
35
36
download_file ,
36
37
safe_urljoin ,
37
38
FLASHINFER_CUBIN_DIR ,
39
+ download_file ,
38
40
)
39
41
40
42
@@ -78,50 +80,88 @@ def get_available_cubin_files(
78
80
return tuple ()
79
81
80
82
81
- @dataclass (frozen = True )
82
83
class ArtifactPath :
83
- TRTLLM_GEN_FMHA : str = "7206d64e67f4c8949286246d6e2e07706af5d223 /fmha/trtllm-gen"
84
+ TRTLLM_GEN_FMHA : str = "a72d85b019dc125b9f711300cb989430f762f5a6 /fmha/trtllm-gen/ "
84
85
TRTLLM_GEN_BMM : str = (
85
- "e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0 /batched_gemm-45beda1-ee6a802 "
86
+ "a72d85b019dc125b9f711300cb989430f762f5a6 /batched_gemm-145d1b1-9e1d49a/ "
86
87
)
87
88
TRTLLM_GEN_GEMM : str = (
88
- "037e528e719ec3456a7d7d654f26b805e44c63b1 /gemm-8704aa4 -f91dc9e"
89
+ "a72d85b019dc125b9f711300cb989430f762f5a6 /gemm-145d1b1 -f91dc9e/ "
89
90
)
90
- CUDNN_SDPA : str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn"
91
- DEEPGEMM : str = "51d730202c9eef782f06ecc950005331d85c5d4b/deep-gemm"
92
-
91
+ CUDNN_SDPA : str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
92
+ DEEPGEMM : str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
93
93
94
+ # TODO: Should be deprecated
94
95
@dataclass (frozen = True )
95
96
class MetaInfoHash :
96
97
TRTLLM_GEN_FMHA : str = (
97
98
"2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
98
99
)
99
100
TRTLLM_GEN_BMM : str = (
100
- "c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34 "
101
+ "9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4 "
101
102
)
102
103
DEEPGEMM : str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
103
104
TRTLLM_GEN_GEMM : str = (
104
- "0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba "
105
+ "7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67 "
105
106
)
106
107
107
108
108
- def get_cubin_file_list () -> Generator [str , None , None ]:
109
- base = FLASHINFER_CUBINS_REPOSITORY
109
+ # @dataclass(frozen=True)
110
+ class CheckSumHash :
111
+ TRTLLM_GEN_FMHA : str = "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
112
+ TRTLLM_GEN_BMM : str = "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
113
+ DEEPGEMM : str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
114
+ TRTLLM_GEN_GEMM : str = "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
110
115
111
- # The meta info header files first.
112
- yield safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "include/flashInferMetaInfo.h" )
113
- yield safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "include/flashinferMetaInfo.h" )
114
- yield safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "include/flashinferMetaInfo.h" )
116
+ map_checksums : [dict [str , str ]] = {
117
+ safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "checksums.txt" ): TRTLLM_GEN_FMHA ,
118
+ safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "checksums.txt" ): TRTLLM_GEN_BMM ,
119
+ safe_urljoin (ArtifactPath .DEEPGEMM , "checksums.txt" ): DEEPGEMM ,
120
+ safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "checksums.txt" ): TRTLLM_GEN_GEMM ,
121
+ }
115
122
116
- # All the actual kernel cubin's.
117
- for kernel in [
123
+
124
+ def get_checksums (subdirs ):
125
+ checksums = {}
126
+ for subdir in subdirs :
127
+ uri = safe_urljoin (FLASHINFER_CUBINS_REPOSITORY , safe_urljoin (subdir , "checksums.txt" ))
128
+ checksum_path = FLASHINFER_CUBIN_DIR / safe_urljoin (subdir , "checksums.txt" )
129
+ download_file (uri , checksum_path )
130
+ with open (checksum_path , "r" ) as f :
131
+ for line in f :
132
+ sha256 , filename = line .strip ().split ()
133
+
134
+ # Distinguish between all meta info header files
135
+ if ".h" in filename :
136
+ filename = safe_urljoin (subdir , filename )
137
+ checksums [filename ] = sha256
138
+ return checksums
139
+
140
+
141
+ def get_subdir_file_list ():
142
+ base = FLASHINFER_CUBINS_REPOSITORY
143
+
144
+ cubin_dirs = [
118
145
ArtifactPath .TRTLLM_GEN_FMHA ,
119
146
ArtifactPath .TRTLLM_GEN_BMM ,
120
147
ArtifactPath .TRTLLM_GEN_GEMM ,
121
148
ArtifactPath .DEEPGEMM ,
122
- ]:
123
- for name in get_available_cubin_files (safe_urljoin (base , kernel )):
124
- yield safe_urljoin (kernel , name )
149
+ ]
150
+
151
+ # Get checksums of all files
152
+ checksums = get_checksums (cubin_dirs )
153
+
154
+ # The meta info header files first.
155
+ yield (safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "include/flashInferMetaInfo.h" ), checksums [safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "include/flashInferMetaInfo.h" )])
156
+ yield (safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "include/flashinferMetaInfo.h" ), checksums [safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "include/flashinferMetaInfo.h" )])
157
+ yield (safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "include/flashinferMetaInfo.h" ), checksums [safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "include/flashinferMetaInfo.h" )])
158
+
159
+ # All the actual kernel cubin's.
160
+ for cubin_dir in cubin_dirs :
161
+ checksum_path = safe_urljoin (cubin_dir , "checksums.txt" )
162
+ yield (checksum_path , CheckSumHash .map_checksums [checksum_path ])
163
+ for name in get_available_cubin_files (safe_urljoin (base , cubin_dir )):
164
+ yield (safe_urljoin (cubin_dir , name ), checksums [name ])
125
165
126
166
127
167
def download_artifacts () -> None :
@@ -130,13 +170,11 @@ def download_artifacts() -> None:
130
170
# use a shared session to make use of HTTP keep-alive and reuse of
131
171
# HTTPS connections.
132
172
session = requests .Session ()
133
-
134
- cubin_files = list (get_cubin_file_list ())
173
+ cubin_files = list (get_subdir_file_list ())
135
174
num_threads = int (os .environ .get ("FLASHINFER_CUBIN_DOWNLOAD_THREADS" , "4" ))
136
175
with tqdm_logging_redirect (
137
176
total = len (cubin_files ), desc = "Downloading cubins"
138
177
) as pbar :
139
-
140
178
def update_pbar_cb (_ ) -> None :
141
179
pbar .update (1 )
142
180
@@ -165,7 +203,7 @@ def get_artifacts_status() -> tuple[tuple[str, bool], ...]:
165
203
Check which cubins are already downloaded and return (num_downloaded, total).
166
204
Does not download any cubins.
167
205
"""
168
- cubin_files = get_cubin_file_list ()
206
+ cubin_files = get_subdir_file_list ()
169
207
170
208
def _check_file_status (file_name : str ) -> tuple [str , bool ]:
171
209
# get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path
@@ -174,7 +212,7 @@ def _check_file_status(file_name: str) -> tuple[str, bool]:
174
212
exists = os .path .isfile (local_path )
175
213
return (file_name , exists )
176
214
177
- return tuple (_check_file_status (file_name ) for file_name in cubin_files )
215
+ return tuple (_check_file_status (file_name ) for file_name , _ in cubin_files )
178
216
179
217
180
218
def clear_cubin ():
0 commit comments