11# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
22# SPDX-License-Identifier: MIT
33
4- #TODO: compile to a shared library
5- cmake_minimum_required (VERSION 3.28)
6- set (CMAKE_CXX_STANDARD 20)
7- #TODO: remove after figuring out how to install clang-scan-deps
8- set (CMAKE_CXX_SCAN_FOR_MODULES OFF )
4+ cmake_minimum_required (VERSION 3.21)
5+ set (CMAKE_CXX_STANDARD 17)
96project (ck_fused_attn LANGUAGES HIP CXX)
107
11- # remove files that should be regenerated
12- file (REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp ${CMAKE_CURRENT_BINARY_DIR} /gen_src/blob_list.txt)
138
14- # create gen_src and gen_src/tmp directories if needed
15- file (MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp)
9+ set (AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE" )
1610
1711set (__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR} /../../../3rdparty/aiter" )
12+ set (__AITER_TEST_DIR "${__AITER_SOURCE_DIR} /op_tests/cpp/mha" )
1813set (__CK_SOURCE_DIR "${__AITER_SOURCE_DIR} /3rdparty/composable_kernel" )
1914
2015# so far, there are only gfx942 and gfx950 v3 kernels
@@ -37,158 +32,41 @@ message(STATUS "AITER V3_ASM_ARCHS: ${V3_ASM_ARCHS}")
3732list (JOIN V3_ASM_ARCHS ";" V3_ASM_ARCHS_STR)
3833set (ENV{GPU_ARCHS} "${V3_ASM_ARCHS_STR} " )
3934
40- # generate v2 (CK) kernels
41- # fwd kernels list
42- execute_process (
43- COMMAND python3 ${__CK_SOURCE_DIR} /example/ck_tile/01_fmha/generate.py
44- --api fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR} /gen_src/fwd_blob_list.txt --receipt 600
45- )
46- execute_process (
47- COMMAND python3 ${__CK_SOURCE_DIR} /example/ck_tile/01_fmha/generate.py
48- --api fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR} /gen_src/fwd_splitkv_blob_list.txt --receipt 600
49- )
50- execute_process (
51- COMMAND python3 ${__CK_SOURCE_DIR} /example/ck_tile/01_fmha/generate.py
52- --api batch_prefill --list_blobs ${CMAKE_CURRENT_BINARY_DIR} /gen_src/fwd_batch_prefill_blob_list.txt --receipt 600
53- )
54-
55- # bwd kernels list
56- execute_process (
57- COMMAND python3 ${__CK_SOURCE_DIR} /example/ck_tile/01_fmha/generate.py
58- --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR} /gen_src/bwd_blob_list.txt --receipt 600
59- )
60-
61- file (STRINGS ${CMAKE_CURRENT_BINARY_DIR} /gen_src/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
62- file (STRINGS ${CMAKE_CURRENT_BINARY_DIR} /gen_src/fwd_splitkv_blob_list.txt FMHA_FWD_SPLITKV_GEN_BLOBS)
63- file (STRINGS ${CMAKE_CURRENT_BINARY_DIR} /gen_src/fwd_batch_prefill_blob_list.txt FMHA_FWD_BATCH_PREFILL_GEN_BLOBS)
64- file (STRINGS ${CMAKE_CURRENT_BINARY_DIR} /gen_src/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
65-
66- # generate the actual fwd kernel cpp files
67- execute_process (
68- COMMAND python3 ${__CK_SOURCE_DIR} /example/ck_tile/01_fmha/generate.py
69- --api fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp --receipt 600
70- )
71-
72- execute_process (
73- COMMAND python3 ${__CK_SOURCE_DIR} /example/ck_tile/01_fmha/generate.py
74- --api fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp --receipt 600
75- )
76-
77- execute_process (
78- COMMAND python3 ${__CK_SOURCE_DIR} /example/ck_tile/01_fmha/generate.py
79- --api batch_prefill --output_dir ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp --receipt 600
80- )
81-
82- # generate the aiter fwd interface cpp file
83- execute_process (
84- COMMAND python3 ${__AITER_SOURCE_DIR} /csrc/cpp_itfs/mha_fwd_generate.py
85- --output_dir ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp --receipt 5
86- )
87-
88- # generate the actual bwd kernel cpp files
89- execute_process (
90- COMMAND python3 ${__CK_SOURCE_DIR} /example/ck_tile/01_fmha/generate.py
91- --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp --receipt 600
92- )
93-
94- # generate the aiter bwd interface cpp file
95- execute_process (
96- COMMAND python3 ${__AITER_SOURCE_DIR} /csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py
97- --filter *@*_ndeterministic@*_nbias*_dropout*_ndeterministic* --output_dir ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp
98- )
99-
100- execute_process (
101- COMMAND python3 ${__AITER_SOURCE_DIR} /csrc/cpp_itfs/mha_bwd_generate.py
102- --receipt 3 --output_dir ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp
103- )
104-
105- # generate fwd/bwd v3 kernels for each requested rocm arch
106- foreach (CK_TARGET_ARCH IN LISTS V3_ASM_ARCHS)
107- execute_process (
108- COMMAND python3 ${__AITER_SOURCE_DIR} /hsa/${CK_TARGET_ARCH} /fmha_v3_fwd/codegen.py
109- --output_dir ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp
110- )
35+ if (NOT DEFINED AITER_MHA_PATH)
36+ # delete the existing aiter/jit/build dir for a clean build
37+ file (REMOVE_RECURSE "${__AITER_SOURCE_DIR} /aiter/jit/build" )
38+ # compile the libmha_fwd.so and libmha_bwd.so
39+ set (ENV{AITER_LOG_MORE} 1)
40+ # fp32 to bf16 cvt env still required for MI300X
41+ set (ENV{CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT} ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} )
11142 execute_process (
112- COMMAND python3 ${__AITER_SOURCE_DIR} /hsa/${CK_TARGET_ARCH} /fmha_v3_bwd/codegen.py
113- --output_dir ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp
43+ COMMAND python3 ${__AITER_TEST_DIR} /compile.py
11444 )
115- endforeach ()
45+ # libmha_fwd.so and libmha_bwd.so will be under 3rdparty/aiter/op_tests/cpp/mha
46+ set (__AITER_MHA_PATH ${__AITER_TEST_DIR} )
47+ else ()
48+ # use pre-built libmha_fwd.so libmha_bwd.so
49+ set (__AITER_MHA_PATH ${AITER_MHA_PATH} )
50+ endif ()
11651
11752set (ck_fused_attn_SOURCES)
11853list (APPEND ck_fused_attn_SOURCES
11954 src/ck_fused_attn_fwd.cpp
12055 src/ck_fused_attn_bwd.cpp
12156 src/ck_fused_attn_utils.cpp)
12257
123- foreach (blob ${FMHA_FWD_GEN_BLOBS} )
124- file (RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR} /gen_src ${blob} )
125- file (COPY_FILE ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
126- endforeach ()
127- list (APPEND ck_fused_attn_SOURCES ${FMHA_FWD_GEN_BLOBS} )
128-
129- foreach (blob ${FMHA_FWD_SPLITKV_GEN_BLOBS} )
130- file (RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR} /gen_src ${blob} )
131- file (COPY_FILE ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
132- endforeach ()
133- list (APPEND ck_fused_attn_SOURCES ${FMHA_FWD_SPLITKV_GEN_BLOBS} )
134-
135- foreach (blob ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS} )
136- file (RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR} /gen_src ${blob} )
137- file (COPY_FILE ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
138- endforeach ()
139- list (APPEND ck_fused_attn_SOURCES ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS} )
140-
141- foreach (blob ${FMHA_BWD_GEN_BLOBS} )
142- file (RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR} /gen_src ${blob} )
143- file (COPY_FILE ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
144- endforeach ()
145- list (APPEND ck_fused_attn_SOURCES ${FMHA_BWD_GEN_BLOBS} )
146-
147- # add generated cpp files into ck_fused_attn_sources
148- set (MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR} /gen_src/mha_bwd.cpp" )
149- set (MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR} /gen_src/mha_fwd.cpp" )
150-
151- file (RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR} /gen_src ${MHA_BWD_SRC} )
152- file (COPY_FILE ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp/${blob_path} ${MHA_BWD_SRC} ONLY_IF_DIFFERENT)
153-
154- file (RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR} /gen_src ${MHA_FWD_SRC} )
155- file (COPY_FILE ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp/${blob_path} ${MHA_FWD_SRC} ONLY_IF_DIFFERENT)
156-
157- list (APPEND ck_fused_attn_SOURCES ${MHA_BWD_SRC} ${MHA_FWD_SRC} )
158-
159- foreach (CK_TARGET_ARCH IN LISTS V3_ASM_ARCHS)
160- set (ASM_MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR} /gen_src/asm_fmha_fwd_v3_${CK_TARGET_ARCH} .cpp" )
161- set (ASM_MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR} /gen_src/asm_fmha_bwd_v3_${CK_TARGET_ARCH} .cpp" )
162-
163- file (RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR} /gen_src ${ASM_MHA_BWD_SRC} )
164- file (COPY_FILE ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp/${blob_path} ${ASM_MHA_BWD_SRC} ONLY_IF_DIFFERENT)
165-
166- file (RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR} /gen_src ${ASM_MHA_FWD_SRC} )
167- file (COPY_FILE ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp/${blob_path} ${ASM_MHA_FWD_SRC} ONLY_IF_DIFFERENT)
168- list (APPEND ck_fused_attn_SOURCES ${ASM_MHA_BWD_SRC} ${ASM_MHA_FWD_SRC} )
169- endforeach ()
170-
171- # remove all previously generated temporary files
172- file (REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR} /gen_src/tmp)
173-
17458message (STATUS "Found the following fused attention files:" )
17559foreach (file ${ck_fused_attn_SOURCES} )
17660 message (STATUS " ${file} " )
17761endforeach ()
17862
179- add_library (ck_fused_attn STATIC ${ck_fused_attn_SOURCES} )
63+ add_library (ck_fused_attn SHARED ${ck_fused_attn_SOURCES} )
18064set (CK_FUSED_ATTN_COMPILE_OPTIONS)
18165list (APPEND CK_FUSED_ATTN_COMPILE_OPTIONS
182- -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -DCK_TILE_FMHA_FWD_SPLITKV_API=1-DCK_TILE_FMHA_FWD_APPENDKV_API=0
183- -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}
184- -fgpu-flush-denormals-to-zero -ftemplate-backtrace-limit=0 -fPIC
185- -Wno-undefined-func-template -Wno-float-equal -Wno-gnu-line-marker -Wunused-variable -Wuninitialized
186- "SHELL:-mllvm -enable-post-misched=0" "SHELL:-mllvm -amdgpu-early-inline-all=true"
187- "SHELL:-mllvm -amdgpu-function-calls=false" "SHELL:-mllvm -amdgpu-coerce-illegal-types=1"
188- "SHELL:-mllvm --amdgpu-kernarg-preload-count=16" )
66+ -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} )
18967
190- foreach (CK_TARGET_ARCH IN LISTS CMAKE_HIP_ARCHITECTURES )
191- list (APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${CK_TARGET_ARCH } )
68+ foreach (ARCH IN LISTS V3_ASM_ARCHS )
69+ list (APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${ARCH } )
19270endforeach ()
19371
19472set (CK_INCLUDE_DIR "${__CK_SOURCE_DIR} /include" )
@@ -216,18 +94,22 @@ target_include_directories(ck_fused_attn PRIVATE ${CK_INCLUDE_DIR} ${__CK_SOURCE
21694target_include_directories (ck_fused_attn PRIVATE ${AITER_INCLUDE_DIR} )
21795
21896find_package (hip)
219- list (APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64)
97+ list (APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 ${__AITER_MHA_PATH} /libmha_fwd.so ${__AITER_MHA_PATH} /libmha_bwd.so )
22098target_link_libraries (ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS} )
22199target_compile_options (ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS} )
100+ set_target_properties (ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN" )
222101
102+ install (FILES ${__AITER_MHA_PATH} /libmha_fwd.so ${__AITER_MHA_PATH} /libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX} /${AITER_MHA_INSTALL_PREFIX} /lib)
103+ install (TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX} /${AITER_MHA_INSTALL_PREFIX} /lib)
223104# copy v3 kernels to destination
224105foreach (ARCH IN LISTS V3_ASM_ARCHS)
225106 install (DIRECTORY
226107 ${__AITER_SOURCE_DIR} /hsa/${ARCH} /fmha_v3_fwd
227- DESTINATION ${CMAKE_INSTALL_PREFIX} /transformer_engine /aiter/${ARCH} /
108+ DESTINATION ${CMAKE_INSTALL_PREFIX} /${AITER_MHA_INSTALL_PREFIX} /lib /aiter/${ARCH} /
228109 PATTERN "codegen.py" EXCLUDE )
229110 install (DIRECTORY
230111 ${__AITER_SOURCE_DIR} /hsa/${ARCH} /fmha_v3_bwd
231- DESTINATION ${CMAKE_INSTALL_PREFIX} /transformer_engine /aiter/${ARCH} /
112+ DESTINATION ${CMAKE_INSTALL_PREFIX} /${AITER_MHA_INSTALL_PREFIX} /lib /aiter/${ARCH} /
232113 PATTERN "codegen.py" EXCLUDE )
233114endforeach ()
115+
0 commit comments