1
+ set (CMAKE_EXPORT_COMPILE_COMMANDS ON )
2
+ cmake_minimum_required (VERSION 3.18)
3
+
4
+ project (flash_attn LANGUAGES CXX CUDA)
5
+
6
+ set (CMAKE_JOB_POOLS cuda=6)
7
+ set (CMAKE_INSTALL_RPATH "$ORIGIN/nvidia/cuda_runtime/lib" )
8
+ # Make sure RPATH is used instead of RUNPATH
9
+ set (CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE )
10
+
11
+ # == Find dependencies ==
12
+ find_package (Python REQUIRED COMPONENTS Interpreter Development.Module)
13
+
14
+ execute_process (
15
+ COMMAND ${Python_EXECUTABLE} -m pybind11 --cmakedir
16
+ OUTPUT_VARIABLE pybind11_DIR
17
+ OUTPUT_STRIP_TRAILING_WHITESPACE
18
+ )
19
+
20
+ find_package (pybind11 CONFIG REQUIRED)
21
+
22
+ # == Setup CUDA ==
23
+ string (REGEX REPLACE "--generate-code=arch=compute_[0-9]+,code=\\ [?compute_[0-9]+,sm_[0-9]+\\ ]?" ""
24
+ CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} " )
25
+ string (REGEX REPLACE "-gencode arch=compute_[0-9]+,code=sm_[0-9]+" ""
26
+ CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} " )
27
+
28
+ message (WARNING "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS} " )
29
+
30
+ # Set up ccache
31
+ find_program (CCACHE_PROGRAM ccache)
32
+ if (CCACHE_PROGRAM)
33
+ set (CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM} " )
34
+ message (STATUS "Using ccache: ${CCACHE_PROGRAM} " )
35
+ endif ()
36
+
37
+ # Options from environment variables
38
+ option (FLASH_ATTENTION_FORCE_BUILD "Force building from source" OFF )
39
+ option (FLASH_ATTENTION_SKIP_CUDA_BUILD "Skip CUDA build" OFF )
40
+ option (FLASH_ATTENTION_FORCE_CXX11_ABI "Force using C++11 ABI" OFF )
41
+
42
+ # CUDA handling
43
+ # Get CUDA architectures from environment or use default
44
+ if (DEFINED ENV{FLASH_ATTN_CUDA_ARCHS})
45
+ set (CMAKE_CUDA_ARCHITECTURES $ENV{FLASH_ATTN_CUDA_ARCHS} )
46
+ else ()
47
+ # set(CMAKE_CUDA_ARCHITECTURES "80;90;100;120")
48
+ set (CMAKE_CUDA_ARCHITECTURES "80" )
49
+ endif ()
50
+
51
+ find_package (CUDAToolkit REQUIRED)
52
+
53
+
54
+ # CUDA flags
55
+ set (CUDA_FLAGS
56
+ -O3
57
+ -std=c++20
58
+ --use_fast_math
59
+ --expt-relaxed-constexpr
60
+ --expt-extended-lambda
61
+ -U__CUDA_NO_HALF_OPERATORS__
62
+ -U__CUDA_NO_HALF_CONVERSIONS__
63
+ -U__CUDA_NO_HALF2_OPERATORS__
64
+ -U__CUDA_NO_BFLOAT16_CONVERSIONS__
65
+ )
66
+
67
+ # Collect source files
68
+ file (GLOB CUDA_SOURCES
69
+ "csrc/flash_attn/src/flash_fwd_hdim*.cu"
70
+ "csrc/flash_attn/src/flash_bwd_hdim*.cu"
71
+ "csrc/flash_attn/src/flash_fwd_split_hdim*.cu"
72
+ )
73
+
74
+ file (GLOB CC_SOURCES
75
+ "csrc/flash_attn/*.cpp"
76
+ )
77
+
78
+ # Create CUDA extension
79
+ pybind11_add_module(flash_api
80
+ ${CC_SOURCES}
81
+ ${CUDA_SOURCES}
82
+ )
83
+
84
+ set_property (TARGET flash_api PROPERTY JOB_POOL_COMPILE cuda)
85
+
86
+ target_compile_options (flash_api PRIVATE
87
+ $<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS} >
88
+ )
89
+
90
+ target_include_directories (flash_api PRIVATE
91
+ ${CMAKE_CURRENT_SOURCE_DIR} /csrc/flash_attn
92
+ ${CMAKE_CURRENT_SOURCE_DIR} /csrc/flash_attn/src
93
+ ${CMAKE_CURRENT_SOURCE_DIR} /csrc/cutlass/include
94
+ )
95
+
96
+ target_link_libraries (flash_api PRIVATE
97
+ CUDA::cudart
98
+ )
99
+
100
+ if (FLASH_ATTENTION_FORCE_CXX11_ABI)
101
+ target_compile_definitions (flash_api PRIVATE
102
+ _GLIBCXX_USE_CXX11_ABI=1
103
+ )
104
+ endif ()
105
+
106
+ # Installation
107
+ install (TARGETS flash_api
108
+ DESTINATION ${SKBUILD_PLATLIB_DIR} /flash_attn_jax_lib
109
+ )
110
+
111
+ install (DIRECTORY src/flash_attn_jax/
112
+ DESTINATION ${SKBUILD_PLATLIB_DIR} /flash_attn_jax
113
+ FILES_MATCHING PATTERN "*.py"
114
+ )
0 commit comments