Skip to content

Commit 96d6da0

Browse files
committed
Refactors CUDA extension sources in setup.py to use glob for dynamic file inclusion
1 parent 77edcb0 commit 96d6da0

File tree

1 file changed

+6
-81
lines changed

1 file changed

+6
-81
lines changed

setup.py

Lines changed: 6 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -206,87 +206,12 @@ def append_nvcc_threads(nvcc_extra_args):
206 206
ext_modules.append(
207 207
CUDAExtension(
208 208
name="flash_dmattn_cuda",
209-
sources=[
210-
"csrc/flash_dmattn/flash_api.cpp",
211-
# Forward kernels - regular
212-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu",
213-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu",
214-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu",
215-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu",
216-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu",
217-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu",
218-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu",
219-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu",
220-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu",
221-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu",
222-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu",
223-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu",
224-
# Forward kernels - causal
225-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu",
226-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu",
227-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu",
228-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu",
229-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu",
230-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu",
231-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu",
232-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu",
233-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu",
234-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu",
235-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu",
236-
"csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu",
237-
# Forward kernels - split
238-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu",
239-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu",
240-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu",
241-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu",
242-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu",
243-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu",
244-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu",
245-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu",
246-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu",
247-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu",
248-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu",
249-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu",
250-
# Forward kernels - split causal
251-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
252-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
253-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
254-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
255-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
256-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
257-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
258-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
259-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
260-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
261-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
262-
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
263-
# Backward kernels - regular
264-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu",
265-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu",
266-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu",
267-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu",
268-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu",
269-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu",
270-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu",
271-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu",
272-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu",
273-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu",
274-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu",
275-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu",
276-
# Backward kernels - causal
277-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu",
278-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu",
279-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu",
280-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu",
281-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu",
282-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu",
283-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu",
284-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu",
285-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu",
286-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu",
287-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu",
288-
"csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu",
289-
],
209+
sources=(
210+
[
211+
"csrc/flash_dmattn/flash_api.cpp",
212+
]
213+
+ sorted(glob.glob("csrc/flash_dmattn/src/instantiations/flash_*.cu"))
214+
),
290 215
extra_compile_args={
291 216
"cxx": compiler_c17_flag,
292 217
"nvcc": append_nvcc_threads(nvcc_flags + cc_flag),

0 commit comments

Comments
 (0)