27
27
FLASHINFER_CUBINS_REPOSITORY ,
28
28
get_cubin ,
29
29
FLASHINFER_CUBIN_DIR ,
30
+ download_file ,
30
31
)
31
32
32
33
@@ -71,43 +72,71 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
71
72
class ArtifactPath :
72
73
TRTLLM_GEN_FMHA : str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen/"
73
74
TRTLLM_GEN_BMM : str = (
74
- "e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0 /batched_gemm-45beda1-ee6a802 /"
75
+ "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1 /batched_gemm-45beda1-7bdba93 /"
75
76
)
76
77
TRTLLM_GEN_GEMM : str = (
77
- "037e528e719ec3456a7d7d654f26b805e44c63b1 /gemm-8704aa4 -f91dc9e/"
78
+ "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1 /gemm-45beda1 -f91dc9e/"
78
79
)
79
- CUDNN_SDPA : str = "4c623163877c8fef5751c9c7a59940cd2baae02e /fmha/cudnn/"
80
- DEEPGEMM : str = "51d730202c9eef782f06ecc950005331d85c5d4b /deep-gemm/"
80
+ CUDNN_SDPA : str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1 /fmha/cudnn/"
81
+ DEEPGEMM : str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1 /deep-gemm/"
81
82
82
83
83
84
class MetaInfoHash :
84
85
TRTLLM_GEN_FMHA : str = (
85
86
"2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
86
87
)
87
88
TRTLLM_GEN_BMM : str = (
88
- "c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34 "
89
+ "9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4 "
89
90
)
90
91
DEEPGEMM : str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
91
92
TRTLLM_GEN_GEMM : str = (
92
- "0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba "
93
+ "7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67 "
93
94
)
94
95
95
96
97
+ def get_checksums (kernels ):
98
+ checksums = {}
99
+ for kernel in kernels :
100
+ uri = FLASHINFER_CUBINS_REPOSITORY + "/" + (kernel + "checksums.txt" )
101
+ checksum_path = FLASHINFER_CUBIN_DIR / (kernel + "checksums.txt" )
102
+ download_file (uri , checksum_path )
103
+ with open (checksum_path , "r" ) as f :
104
+ for line in f :
105
+ sha256 , filename = line .strip ().split ()
106
+ checksums [kernel + filename ] = sha256
107
+ return checksums
108
+
109
+
96
110
def get_cubin_file_list ():
97
111
base = FLASHINFER_CUBINS_REPOSITORY .rstrip ("/" )
98
112
cubin_files = [
99
- (ArtifactPath .TRTLLM_GEN_FMHA + "include/flashInferMetaInfo" , ".h" ),
100
- (ArtifactPath .TRTLLM_GEN_GEMM + "include/flashinferMetaInfo" , ".h" ),
101
- (ArtifactPath .TRTLLM_GEN_BMM + "include/flashinferMetaInfo" , ".h" ),
113
+ (
114
+ ArtifactPath .TRTLLM_GEN_FMHA + "include/flashInferMetaInfo" ,
115
+ ".h" ,
116
+ MetaInfoHash .TRTLLM_GEN_FMHA ,
117
+ ),
118
+ (
119
+ ArtifactPath .TRTLLM_GEN_GEMM + "include/flashinferMetaInfo" ,
120
+ ".h" ,
121
+ MetaInfoHash .TRTLLM_GEN_GEMM ,
122
+ ),
123
+ (
124
+ ArtifactPath .TRTLLM_GEN_BMM + "include/flashinferMetaInfo" ,
125
+ ".h" ,
126
+ MetaInfoHash .TRTLLM_GEN_BMM ,
127
+ ),
102
128
]
103
- for kernel in [
129
+ kernels = [
104
130
ArtifactPath .TRTLLM_GEN_FMHA ,
105
- ArtifactPath .TRTLLM_GEN_BMM ,
106
131
ArtifactPath .TRTLLM_GEN_GEMM ,
132
+ ArtifactPath .TRTLLM_GEN_BMM ,
107
133
ArtifactPath .DEEPGEMM ,
108
- ]:
134
+ ]
135
+ checksums = get_checksums (kernels )
136
+
137
+ for kernel in kernels :
109
138
cubin_files += [
110
- (kernel + name , extension )
139
+ (kernel + name , extension , checksums [ kernel + name + extension ] )
111
140
for name , extension in get_available_cubin_files (
112
141
urljoin (base + "/" , kernel )
113
142
)
@@ -121,27 +150,25 @@ def download_artifacts():
121
150
# use a shared session to make use of HTTP keep-alive and reuse of
122
151
# HTTPS connections.
123
152
session = requests .Session ()
153
+ cubin_files = get_cubin_file_list ()
154
+ num_threads = int (os .environ .get ("FLASHINFER_CUBIN_DOWNLOAD_THREADS" , "4" ))
155
+ with tqdm_logging_redirect (
156
+ total = len (cubin_files ), desc = "Downloading cubins"
157
+ ) as pbar :
124
158
125
- with temp_env_var ("FLASHINFER_CUBIN_CHECKSUM_DISABLED" , "1" ):
126
- cubin_files = get_cubin_file_list ()
127
- num_threads = int (os .environ .get ("FLASHINFER_CUBIN_DOWNLOAD_THREADS" , "4" ))
128
- with tqdm_logging_redirect (
129
- total = len (cubin_files ), desc = "Downloading cubins"
130
- ) as pbar :
131
-
132
- def update_pbar_cb (_ ) -> None :
133
- pbar .update (1 )
159
+ def update_pbar_cb (_ ) -> None :
160
+ pbar .update (1 )
134
161
135
- with ThreadPoolExecutor (num_threads ) as pool :
136
- futures = []
137
- for name , extension in cubin_files :
138
- fut = pool .submit (get_cubin , name , "" , extension , session )
139
- fut .add_done_callback (update_pbar_cb )
140
- futures .append (fut )
162
+ with ThreadPoolExecutor (num_threads ) as pool :
163
+ futures = []
164
+ for name , extension , checksum in cubin_files :
165
+ fut = pool .submit (get_cubin , name , checksum , extension , session )
166
+ fut .add_done_callback (update_pbar_cb )
167
+ futures .append (fut )
141
168
142
- results = [fut .result () for fut in as_completed (futures )]
169
+ results = [fut .result () for fut in as_completed (futures )]
143
170
144
- all_success = all (results )
171
+ all_success = all (results )
145
172
if not all_success :
146
173
raise RuntimeError ("Failed to download cubins" )
147
174
0 commit comments