Skip to content

Commit b08a1ed

Browse files
ipanfilowangye805
andauthored
Bring back aiter solib with aiter update (#327)
* AITER solib with commit fc3c0420 * [ROCm] api call fix and disable v3 fwd with swa (#331) * [ROCm] update aiter commit with gfx950 fix and swa fwd fix --------- Co-authored-by: Ye Wang <[email protected]>
1 parent 1834247 commit b08a1ed

File tree

8 files changed

+55
-198
lines changed

8 files changed

+55
-198
lines changed

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,5 @@ compile_commands.json
5252
**/profiler_outputs/
5353
**/times.csv
5454
tensor_dumps/
55-
aiter/
5655
transformer_engine/build_info.txt
5756
transformer_engine/common/util/hip_nvml.*
58-
transformer_engine/aiter/

3rdparty/aiter

Submodule aiter updated 273 files

setup.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434

3535
from setuptools.command.build_ext import build_ext as BuildExtension
36-
from setuptools.command.develop import develop as _develop
3736

3837
os.environ["NVTE_PROJECT_BUILDING"] = "1"
3938

@@ -48,26 +47,6 @@
4847
if not rocm_build():
4948
archs = cuda_archs()
5049

51-
# A custom develop command only used for ROCm builds
52-
class develop(_develop):
53-
def run(self):
54-
super().run()
55-
if (
56-
int(os.getenv("NVTE_FUSED_ATTN_CK", "1")) and
57-
int(os.getenv("NVTE_FUSED_ATTN", "1"))
58-
):
59-
# Ensure that the AITER ASM kernels are properly available at runtime
60-
# by creating a symlink to them. This is only necessary for editable
61-
# mode since our C++ code assumes the AITER ASM kernel paths relative
62-
# to trasnformer_engine.so, which is different in editable installs.
63-
project_dir = Path(__file__).parent
64-
asm_src_dir = project_dir / 'transformer_engine' / 'aiter'
65-
# Must be synced with
66-
# TransformerEngine/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
67-
asm_target_dir = project_dir / 'aiter'
68-
if asm_src_dir.is_dir() and not asm_target_dir.is_dir():
69-
asm_target_dir.symlink_to(asm_src_dir)
70-
7150
class TimedBdist(bdist_wheel):
7251
"""Helper class to measure build time"""
7352

@@ -89,7 +68,7 @@ def setup_common_extension() -> CMakeExtension:
8968
cmake_flags.append(f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', 3)}")
9069
if os.getenv("NVTE_CK_FUSED_ATTN_PATH"):
9170
ck_path = Path(os.getenv("NVTE_CK_FUSED_ATTN_PATH"))
92-
cmake_flags.append(f"-DCK_FUSED_ATTN_PATH={ck_path}")
71+
cmake_flags.append(f"-DAITER_MHA_PATH={ck_path}")
9372
if int(os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
9473
cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF")
9574
if int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
@@ -192,14 +171,14 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
192171
with open("README.rst", encoding="utf-8") as f:
193172
long_description = f.read()
194173

195-
cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
196174
# Settings for building top level empty package for dependency management.
197175
if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
198176
assert bool(
199177
int(os.getenv("NVTE_RELEASE_BUILD", "0"))
200178
), "NVTE_RELEASE_BUILD env must be set for metapackage build."
201179
te_cuda_vers = "rocm" if rocm_build() else "cu12"
202180
ext_modules = []
181+
cmdclass = {}
203182
package_data = {}
204183
include_package_data = False
205184
setup_requires = []
@@ -211,8 +190,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
211190
else:
212191
setup_requires, install_requires, test_requires = setup_requirements()
213192
ext_modules = [setup_common_extension()]
214-
if rocm_build():
215-
cmdclass["develop"] = develop
193+
cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
216194
package_data = {"": ["VERSION.txt"]}
217195
include_package_data = True
218196
extras_require = {"test": test_requires}
@@ -255,7 +233,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
255233
long_description=long_description,
256234
long_description_content_type="text/x-rst",
257235
ext_modules=ext_modules,
258-
cmdclass=cmdclass,
236+
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
259237
python_requires=">=3.8, <3.13",
260238
classifiers=[
261239
"Programming Language :: Python :: 3.8",

transformer_engine/common/CMakeLists.txt

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -351,18 +351,7 @@ else()
351351
endif()
352352

353353
if(USE_FUSED_ATTN_CK)
354-
if(NOT DEFINED CK_FUSED_ATTN_PATH)
355-
set(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} CACHE STRING "ck float to bf16 conversion rounding")
356-
add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn)
357-
else()
358-
# Use CK built during initial TE building/installation
359-
# When only need rebuild TE library itself
360-
unset(CK_FUSED_ATTN_LIB CACHE)
361-
find_library(CK_FUSED_ATTN_LIB NAMES ck_fused_attn PATHS ${CK_FUSED_ATTN_PATH}/lib REQUIRED NO_DEFAULT_PATH)
362-
add_library( ck_fused_attn STATIC IMPORTED )
363-
set_target_properties( ck_fused_attn PROPERTIES IMPORTED_LOCATION ${CK_FUSED_ATTN_LIB} )
364-
target_include_directories(ck_fused_attn INTERFACE ${CK_FUSED_ATTN_PATH}/include)
365-
endif()
354+
add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn)
366355
endif()
367356

368357
find_package(hip)
Lines changed: 29 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
11
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
22
# SPDX-License-Identifier: MIT
33

4-
#TODO: compile to a shared library
5-
cmake_minimum_required(VERSION 3.28)
6-
set(CMAKE_CXX_STANDARD 20)
7-
#TODO: remove after figuring out how to install clang-scan-deps
8-
set(CMAKE_CXX_SCAN_FOR_MODULES OFF)
4+
cmake_minimum_required(VERSION 3.21)
5+
set(CMAKE_CXX_STANDARD 17)
96
project(ck_fused_attn LANGUAGES HIP CXX)
107

11-
# remove files that should be regenerated
12-
file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp ${CMAKE_CURRENT_BINARY_DIR}/gen_src/blob_list.txt)
138

14-
# create gen_src and gen_src/tmp directories if needed
15-
file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp)
9+
set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE")
1610

1711
set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter")
12+
set(__AITER_TEST_DIR "${__AITER_SOURCE_DIR}/op_tests/cpp/mha")
1813
set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel")
1914

2015
# so far, there are only gfx942 and gfx950 v3 kernels
@@ -37,158 +32,41 @@ message(STATUS "AITER V3_ASM_ARCHS: ${V3_ASM_ARCHS}")
3732
list(JOIN V3_ASM_ARCHS ";" V3_ASM_ARCHS_STR)
3833
set(ENV{GPU_ARCHS} "${V3_ASM_ARCHS_STR}")
3934

40-
# generate v2 (CK) kernels
41-
# fwd kernels list
42-
execute_process(
43-
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
44-
--api fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt --receipt 600
45-
)
46-
execute_process(
47-
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
48-
--api fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_splitkv_blob_list.txt --receipt 600
49-
)
50-
execute_process(
51-
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
52-
--api batch_prefill --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_batch_prefill_blob_list.txt --receipt 600
53-
)
54-
55-
# bwd kernels list
56-
execute_process(
57-
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
58-
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt --receipt 600
59-
)
60-
61-
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
62-
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_splitkv_blob_list.txt FMHA_FWD_SPLITKV_GEN_BLOBS)
63-
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_batch_prefill_blob_list.txt FMHA_FWD_BATCH_PREFILL_GEN_BLOBS)
64-
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
65-
66-
# generate the actual fwd kernel cpp files
67-
execute_process(
68-
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
69-
--api fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
70-
)
71-
72-
execute_process(
73-
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
74-
--api fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
75-
)
76-
77-
execute_process(
78-
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
79-
--api batch_prefill --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
80-
)
81-
82-
# generate the aiter fwd interface cpp file
83-
execute_process(
84-
COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/cpp_itfs/mha_fwd_generate.py
85-
--output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 5
86-
)
87-
88-
# generate the actual bwd kernel cpp files
89-
execute_process(
90-
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
91-
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
92-
)
93-
94-
# generate the aiter bwd interface cpp file
95-
execute_process(
96-
COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py
97-
--filter *@*_ndeterministic@*_nbias*_dropout*_ndeterministic* --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
98-
)
99-
100-
execute_process(
101-
COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/cpp_itfs/mha_bwd_generate.py
102-
--receipt 3 --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
103-
)
104-
105-
# generate fwd/bwd v3 kernels for each requested rocm arch
106-
foreach(CK_TARGET_ARCH IN LISTS V3_ASM_ARCHS)
107-
execute_process(
108-
COMMAND python3 ${__AITER_SOURCE_DIR}/hsa/${CK_TARGET_ARCH}/fmha_v3_fwd/codegen.py
109-
--output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
110-
)
35+
if(NOT DEFINED AITER_MHA_PATH)
36+
# delete the existing aiter/jit/build dir for a clean build
37+
file(REMOVE_RECURSE "${__AITER_SOURCE_DIR}/aiter/jit/build")
38+
# compile the libmha_fwd.so and libmha_bwd.so
39+
set(ENV{AITER_LOG_MORE} 1)
40+
# fp32 to bf16 cvt env still required for MI300X
41+
set(ENV{CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT} ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT})
11142
execute_process(
112-
COMMAND python3 ${__AITER_SOURCE_DIR}/hsa/${CK_TARGET_ARCH}/fmha_v3_bwd/codegen.py
113-
--output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
43+
COMMAND python3 ${__AITER_TEST_DIR}/compile.py
11444
)
115-
endforeach()
45+
# libmha_fwd.so and libmha_bwd.so will be under 3rdparty/aiter/op_tests/cpp/mha
46+
set(__AITER_MHA_PATH ${__AITER_TEST_DIR})
47+
else()
48+
# use pre-built libmha_fwd.so libmha_bwd.so
49+
set(__AITER_MHA_PATH ${AITER_MHA_PATH})
50+
endif()
11651

11752
set(ck_fused_attn_SOURCES)
11853
list(APPEND ck_fused_attn_SOURCES
11954
src/ck_fused_attn_fwd.cpp
12055
src/ck_fused_attn_bwd.cpp
12156
src/ck_fused_attn_utils.cpp)
12257

123-
foreach(blob ${FMHA_FWD_GEN_BLOBS})
124-
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
125-
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
126-
endforeach()
127-
list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_GEN_BLOBS})
128-
129-
foreach(blob ${FMHA_FWD_SPLITKV_GEN_BLOBS})
130-
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
131-
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
132-
endforeach()
133-
list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_SPLITKV_GEN_BLOBS})
134-
135-
foreach(blob ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS})
136-
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
137-
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
138-
endforeach()
139-
list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS})
140-
141-
foreach(blob ${FMHA_BWD_GEN_BLOBS})
142-
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
143-
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
144-
endforeach()
145-
list(APPEND ck_fused_attn_SOURCES ${FMHA_BWD_GEN_BLOBS})
146-
147-
# add generated cpp files into ck_fused_attn_sources
148-
set(MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/mha_bwd.cpp")
149-
set(MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/mha_fwd.cpp")
150-
151-
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${MHA_BWD_SRC})
152-
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${MHA_BWD_SRC} ONLY_IF_DIFFERENT)
153-
154-
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${MHA_FWD_SRC})
155-
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${MHA_FWD_SRC} ONLY_IF_DIFFERENT)
156-
157-
list(APPEND ck_fused_attn_SOURCES ${MHA_BWD_SRC} ${MHA_FWD_SRC})
158-
159-
foreach(CK_TARGET_ARCH IN LISTS V3_ASM_ARCHS)
160-
set(ASM_MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/asm_fmha_fwd_v3_${CK_TARGET_ARCH}.cpp")
161-
set(ASM_MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/asm_fmha_bwd_v3_${CK_TARGET_ARCH}.cpp")
162-
163-
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${ASM_MHA_BWD_SRC})
164-
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${ASM_MHA_BWD_SRC} ONLY_IF_DIFFERENT)
165-
166-
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${ASM_MHA_FWD_SRC})
167-
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${ASM_MHA_FWD_SRC} ONLY_IF_DIFFERENT)
168-
list(APPEND ck_fused_attn_SOURCES ${ASM_MHA_BWD_SRC} ${ASM_MHA_FWD_SRC})
169-
endforeach()
170-
171-
# remove all previously generated temporary files
172-
file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp)
173-
17458
message(STATUS "Found the following fused attention files:")
17559
foreach(file ${ck_fused_attn_SOURCES})
17660
message(STATUS " ${file}")
17761
endforeach()
17862

179-
add_library(ck_fused_attn STATIC ${ck_fused_attn_SOURCES})
63+
add_library(ck_fused_attn SHARED ${ck_fused_attn_SOURCES})
18064
set(CK_FUSED_ATTN_COMPILE_OPTIONS)
18165
list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS
182-
-DCK_TILE_FMHA_FWD_FAST_EXP2=1 -DCK_TILE_FMHA_FWD_SPLITKV_API=1-DCK_TILE_FMHA_FWD_APPENDKV_API=0
183-
-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}
184-
-fgpu-flush-denormals-to-zero -ftemplate-backtrace-limit=0 -fPIC
185-
-Wno-undefined-func-template -Wno-float-equal -Wno-gnu-line-marker -Wunused-variable -Wuninitialized
186-
"SHELL:-mllvm -enable-post-misched=0" "SHELL:-mllvm -amdgpu-early-inline-all=true"
187-
"SHELL:-mllvm -amdgpu-function-calls=false" "SHELL:-mllvm -amdgpu-coerce-illegal-types=1"
188-
"SHELL:-mllvm --amdgpu-kernarg-preload-count=16")
66+
-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT})
18967

190-
foreach(CK_TARGET_ARCH IN LISTS CMAKE_HIP_ARCHITECTURES)
191-
list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${CK_TARGET_ARCH})
68+
foreach(ARCH IN LISTS V3_ASM_ARCHS)
69+
list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${ARCH})
19270
endforeach()
19371

19472
set(CK_INCLUDE_DIR "${__CK_SOURCE_DIR}/include")
@@ -216,18 +94,22 @@ target_include_directories(ck_fused_attn PRIVATE ${CK_INCLUDE_DIR} ${__CK_SOURCE
21694
target_include_directories(ck_fused_attn PRIVATE ${AITER_INCLUDE_DIR})
21795

21896
find_package(hip)
219-
list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64)
97+
list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so)
22098
target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS})
22199
target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS})
100+
set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN")
222101

102+
install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
103+
install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
223104
# copy v3 kernels to destination
224105
foreach(ARCH IN LISTS V3_ASM_ARCHS)
225106
install(DIRECTORY
226107
${__AITER_SOURCE_DIR}/hsa/${ARCH}/fmha_v3_fwd
227-
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine/aiter/${ARCH}/
108+
DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/
228109
PATTERN "codegen.py" EXCLUDE)
229110
install(DIRECTORY
230111
${__AITER_SOURCE_DIR}/hsa/${ARCH}/fmha_v3_bwd
231-
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine/aiter/${ARCH}/
112+
DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/
232113
PATTERN "codegen.py" EXCLUDE)
233114
endforeach()
115+

transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -920,8 +920,8 @@ hipError_t ck_attn_varlen_bwd(
920920
cu_seqlen_q_ptr,//cu_seqlen_q
921921
cu_seqlen_kv_ptr,//cu_seqlen_kv
922922
nullptr, /* seqlen_k_ptr */
923-
0, //seqlen_q, unused in group mode
924-
0, //seqlen_kv, unused in group mode
923+
max_seqlen_q, //seqlen_q, unused in group mode
924+
max_seqlen_k, //seqlen_kv, unused in group mode
925925
batch,
926926
max_seqlen_q,
927927
max_seqlen_k,

transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,13 @@ hipError_t ck_attn_fwd(
209209
nullptr,//rand_val_ptr
210210
lse_ptr,
211211
o_ptr,
212-
nullptr,//cu_seqlen_q
213-
nullptr,//cu_seqlen_kv
214-
nullptr, /* seqlen_k_ptr */
212+
nullptr, //cu_seqlen_q
213+
nullptr, //cu_seqlen_kv
214+
nullptr, //seqstart_q_ptr
215+
nullptr, //seqstart_k_ptr
216+
nullptr, //seqlen_k_ptr
217+
nullptr, //seqstart_padded_q_ptr
218+
nullptr, //seqstart_padded_k_ptr
215219
max_seqlen_q,
216220
max_seqlen_k,
217221
batch,
@@ -308,6 +312,7 @@ hipError_t ck_attn_varlen_fwd(
308312
ck_tile::index_t nhead_k = hg;
309313
ck_tile::index_t hdim_v = d_v;
310314
ck_tile::index_t max_seqlen_q = s_q;
315+
ck_tile::index_t max_seqlen_kv = s_kv;
311316

312317
float scale_s = scaling_factor;
313318
float scale_p = 1.f;
@@ -379,11 +384,15 @@ hipError_t ck_attn_varlen_fwd(
379384
nullptr,//rand_val_ptr
380385
lse_thd_ptr,
381386
o_ptr,
382-
cu_seqlen_q_ptr,//cu_seqlen_q
383-
cu_seqlen_kv_ptr,//cu_seqlen_kv
384-
nullptr, /* seqlen_k_ptr */
385-
0, //seqlen_q, unused in group mode
386-
0, //seqlen_kv, unused in group mode
387+
nullptr, //cu_seqlen_q
388+
nullptr, //cu_seqlen_kv
389+
cu_seqlen_q_ptr, //seqstart_q_ptr
390+
cu_seqlen_kv_ptr, //seqstart_k_ptr
391+
nullptr, //seqlen_k_ptr
392+
nullptr, //seqstart_padded_q_ptr
393+
nullptr, //seqstart_padded_k_ptr
394+
max_seqlen_q, //seqlen_q, unused in group mode
395+
max_seqlen_kv, //seqlen_kv, unused in group mode
387396
batch,
388397
max_seqlen_q,
389398
hdim_q,

0 commit comments

Comments
 (0)