27
27
FLASHINFER_CUBINS_REPOSITORY ,
28
28
get_cubin ,
29
29
FLASHINFER_CUBIN_DIR ,
30
+ download_file ,
30
31
)
31
32
32
33
@@ -69,44 +70,72 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
69
70
70
71
71
72
class ArtifactPath:
    """Repository-relative directory prefixes for each kernel family.

    Every value has the form ``<40-hex-digit revision>/<sub-directory>/``.
    Callers append file names straight onto these strings and join them onto
    the repository root with ``+``, so each prefix must end in ``/`` and must
    not contain whitespace.
    """

    # Fused multi-head attention kernels (TRT-LLM generated).
    TRTLLM_GEN_FMHA: str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/fmha/trtllm-gen/"
    # Batched GEMM kernels (TRT-LLM generated).
    TRTLLM_GEN_BMM: str = (
        "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/batched_gemm-45beda1-7bdba93/"
    )
    # GEMM kernels (TRT-LLM generated).
    TRTLLM_GEN_GEMM: str = (
        "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/gemm-45beda1-f91dc9e/"
    )
    # cuDNN scaled-dot-product-attention kernels.
    CUDNN_SDPA: str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/fmha/cudnn/"
    # DeepGEMM kernels.
    DEEPGEMM: str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/deep-gemm/"
81
82
82
83
83
84
class MetaInfoHash:
    """Expected digests for the kernel families' meta-info files.

    Each value is a 64-character lowercase hex string (the length of a
    SHA-256 hex digest). The values are compared against computed digests,
    so they must not carry leading or trailing whitespace.
    """

    TRTLLM_GEN_FMHA: str = (
        "875f50e8f466120b1a59b94397835b86fad785942b4036823230465bc618b919"
    )
    TRTLLM_GEN_BMM: str = (
        "9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4"
    )
    DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
    TRTLLM_GEN_GEMM: str = (
        "7d8ef4e6d89b6990e3e90a3d3a21e96918824d819f8f897a9bfd994925b9ea67"
    )
94
95
95
96
97
def get_checksums(kernels):
    """Download each kernel directory's ``checksums.txt`` and index its entries.

    For every prefix in ``kernels`` the manifest ``<prefix>checksums.txt`` is
    fetched from ``FLASHINFER_CUBINS_REPOSITORY`` into ``FLASHINFER_CUBIN_DIR``
    and then parsed line by line as ``<digest> <filename>``.

    Args:
        kernels: Iterable of repository-relative directory prefixes, each
            ending in ``/`` (the prefix is concatenated directly with the
            manifest name and with every listed file name).

    Returns:
        A dict mapping ``prefix + filename`` to the digest string listed for
        that file. Empty when ``kernels`` is empty.

    Raises:
        ValueError: If a non-blank manifest line does not contain both a
            digest and a file name.
    """
    checksums = {}
    for kernel in kernels:
        manifest = kernel + "checksums.txt"
        uri = FLASHINFER_CUBINS_REPOSITORY + "/" + manifest
        checksum_path = FLASHINFER_CUBIN_DIR / manifest
        download_file(uri, checksum_path)
        with open(checksum_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                # Split only on the first whitespace run so that file names
                # containing spaces stay intact.
                fields = line.split(maxsplit=1)
                if not fields:
                    # Blank line (e.g. trailing newline at end of file).
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f"Malformed checksum entry at {checksum_path}:{line_no}: "
                        f"{line.rstrip()!r}"
                    )
                sha256, filename = fields
                checksums[kernel + filename.rstrip()] = sha256
    return checksums
108
+
109
+
96
110
def get_cubin_file_list ():
97
111
cubin_files = [
98
- (ArtifactPath .TRTLLM_GEN_FMHA + "include/flashInferMetaInfo" , ".h" ),
99
- (ArtifactPath .TRTLLM_GEN_GEMM + "include/flashinferMetaInfo" , ".h" ),
100
- (ArtifactPath .TRTLLM_GEN_BMM + "include/flashinferMetaInfo" , ".h" ),
112
+ (
113
+ ArtifactPath .TRTLLM_GEN_FMHA + "include/flashInferMetaInfo" ,
114
+ ".h" ,
115
+ MetaInfoHash .TRTLLM_GEN_FMHA ,
116
+ ),
117
+ (
118
+ ArtifactPath .TRTLLM_GEN_GEMM + "include/flashinferMetaInfo" ,
119
+ ".h" ,
120
+ MetaInfoHash .TRTLLM_GEN_GEMM ,
121
+ ),
122
+ (
123
+ ArtifactPath .TRTLLM_GEN_BMM + "include/flashinferMetaInfo" ,
124
+ ".h" ,
125
+ MetaInfoHash .TRTLLM_GEN_BMM ,
126
+ ),
101
127
]
102
- for kernel in [
128
+ kernels = [
103
129
ArtifactPath .TRTLLM_GEN_FMHA ,
104
- ArtifactPath .TRTLLM_GEN_BMM ,
105
130
ArtifactPath .TRTLLM_GEN_GEMM ,
131
+ ArtifactPath .TRTLLM_GEN_BMM ,
106
132
ArtifactPath .DEEPGEMM ,
107
- ]:
133
+ ]
134
+ checksums = get_checksums (kernels )
135
+
136
+ for kernel in kernels :
108
137
cubin_files += [
109
- (kernel + name , extension )
138
+ (kernel + name , extension , checksums [ kernel + name + extension ] )
110
139
for name , extension in get_available_cubin_files (
111
140
FLASHINFER_CUBINS_REPOSITORY + "/" + kernel
112
141
)
@@ -120,27 +149,25 @@ def download_artifacts():
120
149
# use a shared session to make use of HTTP keep-alive and reuse of
121
150
# HTTPS connections.
122
151
session = requests .Session ()
152
+ cubin_files = get_cubin_file_list ()
153
+ num_threads = int (os .environ .get ("FLASHINFER_CUBIN_DOWNLOAD_THREADS" , "4" ))
154
+ with tqdm_logging_redirect (
155
+ total = len (cubin_files ), desc = "Downloading cubins"
156
+ ) as pbar :
123
157
124
- with temp_env_var ("FLASHINFER_CUBIN_CHECKSUM_DISABLED" , "1" ):
125
- cubin_files = get_cubin_file_list ()
126
- num_threads = int (os .environ .get ("FLASHINFER_CUBIN_DOWNLOAD_THREADS" , "4" ))
127
- with tqdm_logging_redirect (
128
- total = len (cubin_files ), desc = "Downloading cubins"
129
- ) as pbar :
130
-
131
- def update_pbar_cb (_ ) -> None :
132
- pbar .update (1 )
158
+ def update_pbar_cb (_ ) -> None :
159
+ pbar .update (1 )
133
160
134
- with ThreadPoolExecutor (num_threads ) as pool :
135
- futures = []
136
- for name , extension in cubin_files :
137
- fut = pool .submit (get_cubin , name , "" , extension , session )
138
- fut .add_done_callback (update_pbar_cb )
139
- futures .append (fut )
161
+ with ThreadPoolExecutor (num_threads ) as pool :
162
+ futures = []
163
+ for name , extension , checksum in cubin_files :
164
+ fut = pool .submit (get_cubin , name , checksum , extension , session )
165
+ fut .add_done_callback (update_pbar_cb )
166
+ futures .append (fut )
140
167
141
- results = [fut .result () for fut in as_completed (futures )]
168
+ results = [fut .result () for fut in as_completed (futures )]
142
169
143
- all_success = all (results )
170
+ all_success = all (results )
144
171
if not all_success :
145
172
raise RuntimeError ("Failed to download cubins" )
146
173
0 commit comments