@@ -107,9 +107,9 @@ def get_checksums(kernels):
107
107
return checksums
108
108
109
109
110
- def get_cubin_file_list ():
110
+ def get_subdir_file_list ():
111
111
base = FLASHINFER_CUBINS_REPOSITORY .rstrip ("/" )
112
- cubin_files = [
112
+ subdir_files = [
113
113
(
114
114
ArtifactPath .TRTLLM_GEN_FMHA + "include/flashInferMetaInfo" ,
115
115
".h" ,
@@ -135,13 +135,14 @@ def get_cubin_file_list():
135
135
checksums = get_checksums (kernels )
136
136
137
137
for kernel in kernels :
138
- cubin_files += [
138
+ subdir_files += [(kernel + "checksums" , ".txt" , None )]
139
+ subdir_files += [
139
140
(kernel + name , extension , checksums [kernel + name + extension ])
140
141
for name , extension in get_available_cubin_files (
141
142
urljoin (base + "/" , kernel )
142
143
)
143
144
]
144
- return cubin_files
145
+ return subdir_files
145
146
146
147
147
148
def download_artifacts ():
@@ -150,7 +151,7 @@ def download_artifacts():
150
151
# use a shared session to make use of HTTP keep-alive and reuse of
151
152
# HTTPS connections.
152
153
session = requests .Session ()
153
- cubin_files = get_cubin_file_list ()
154
+ cubin_files = get_subdir_file_list ()
154
155
num_threads = int (os .environ .get ("FLASHINFER_CUBIN_DOWNLOAD_THREADS" , "4" ))
155
156
with tqdm_logging_redirect (
156
157
total = len (cubin_files ), desc = "Downloading cubins"
@@ -162,6 +163,8 @@ def update_pbar_cb(_) -> None:
162
163
with ThreadPoolExecutor (num_threads ) as pool :
163
164
futures = []
164
165
for name , extension , checksum in cubin_files :
166
+ if "checksums" in name :
167
+ continue
165
168
fut = pool .submit (get_cubin , name , checksum , extension , session )
166
169
fut .add_done_callback (update_pbar_cb )
167
170
futures .append (fut )
@@ -178,7 +181,7 @@ def get_artifacts_status():
178
181
Check which cubins are already downloaded and return (num_downloaded, total).
179
182
Does not download any cubins.
180
183
"""
181
- cubin_files = get_cubin_file_list ()
184
+ cubin_files = get_subdir_file_list ()
182
185
status = []
183
186
for name , extension in cubin_files :
184
187
# get_cubin stores cubins in FLASHINFER_CUBIN_DIR with the same relative path
0 commit comments