@@ -206,87 +206,12 @@ def append_nvcc_threads(nvcc_extra_args):
206206 ext_modules .append (
207207 CUDAExtension (
208208 name = "flash_dmattn_cuda" ,
209- sources = [
210- "csrc/flash_dmattn/flash_api.cpp" ,
211- # Forward kernels - regular
212- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu" ,
213- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu" ,
214- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu" ,
215- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu" ,
216- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu" ,
217- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu" ,
218- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu" ,
219- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu" ,
220- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu" ,
221- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu" ,
222- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu" ,
223- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu" ,
224- # Forward kernels - causal
225- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu" ,
226- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu" ,
227- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu" ,
228- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu" ,
229- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu" ,
230- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu" ,
231- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu" ,
232- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu" ,
233- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu" ,
234- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu" ,
235- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu" ,
236- "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu" ,
237- # Forward kernels - split
238- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu" ,
239- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu" ,
240- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu" ,
241- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu" ,
242- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu" ,
243- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu" ,
244- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu" ,
245- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu" ,
246- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu" ,
247- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu" ,
248- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu" ,
249- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu" ,
250- # Forward kernels - split causal
251- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu" ,
252- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu" ,
253- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu" ,
254- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu" ,
255- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu" ,
256- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu" ,
257- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu" ,
258- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu" ,
259- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu" ,
260- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu" ,
261- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu" ,
262- "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu" ,
263- # Backward kernels - regular
264- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu" ,
265- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu" ,
266- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu" ,
267- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu" ,
268- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu" ,
269- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu" ,
270- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu" ,
271- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu" ,
272- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu" ,
273- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu" ,
274- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu" ,
275- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu" ,
276- # Backward kernels - causal
277- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu" ,
278- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu" ,
279- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu" ,
280- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu" ,
281- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu" ,
282- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu" ,
283- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu" ,
284- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu" ,
285- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu" ,
286- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu" ,
287- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu" ,
288- "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu" ,
289- ],
209+ sources = (
210+ [
211+ "csrc/flash_dmattn/flash_api.cpp" ,
212+ ]
213+ + sorted (glob .glob ("csrc/flash_dmattn/src/instantiations/flash_*.cu" ))
214+ ),
290215 extra_compile_args = {
291216 "cxx" : compiler_c17_flag ,
292217 "nvcc" : append_nvcc_threads (nvcc_flags + cc_flag ),
0 commit comments