Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/aotriton
Submodule aotriton updated 178 files
137 changes: 93 additions & 44 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -278,67 +278,116 @@ else()
set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES})

if(USE_FUSED_ATTN_AOTRITON)
# This is for GPU kernel downloading
# The AOTriton C++ runtime will be built from ../../3rdparty/aotriton
# Hence there is no need to add multiple ROCM version here

set(__AOTRITON_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton")
set(__AOTRITON_SUFFIX "_TEprivate")

if(NOT DEFINED AOTRITON_PATH)
# Install aotriton fused attn
# If AOTRITON_PATH is not provided, we proceed to build the runtime
# ourselves and either build or download the GPU kernels
if(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
set(AOTRITON_NOIMAGE_MODE OFF)
else()
set(AOTRITON_NOIMAGE_MODE ON)
endif()

string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
set(__AOTRITON_VER "0.11.1b")
set(__AOTRITON_IMAGE_LIST
"amd-gfx942"
"amd-gfx950"
)
set(__AOTRITON_IMAGE_SHA256_LIST
"0a7bcee19d3bb6d548732248c3234f7b92736c2ab7a7aae65294b87a7fd64c06" # amd-gfx942
"c1ba3bfe84217fd67df3dd1f8b67c80a7f7b33d0ad4d74b41d6567036e032ace" # amd-gfx950
)
set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
include(ExternalProject)
ExternalProject_Add(aotriton_external
LIST_SEPARATOR ","
SOURCE_DIR ${TE}/3rdparty/aotriton
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DAOTRITON_TARGET_ARCH=${ARCH_LIST_COMMA_STR}
-DGPU_TARGETS=${ARCH_LIST_COMMA_STR}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"

# Download GPU kernels for a specific target
function(aotriton_download_image image project)
list(FIND __AOTRITON_IMAGE_LIST ${image} index)
list(GET __AOTRITON_IMAGE_SHA256_LIST ${index} __AOTRITON_IMAGE_SHA256)

string(CONCAT __AOTRITON_FILE
"aotriton-${__AOTRITON_VER}-images-"
"${image}.tar.gz")
string(CONCAT __AOTRITON_URL
"${__AOTRITON_BASE_URL}"
"${__AOTRITON_VER}/${__AOTRITON_FILE}")

# Set up directories
set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton/download/${image})
set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton/image/${image})

ExternalProject_Add(${project}
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_IMAGE_SHA256}
DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${__AOTRITON_EXTRACT_DIR}"
"${__AOTRITON_INSTALL_DIR}"
BUILD_BYPRODUCTS
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
)
message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.")
endfunction()

# Build the AOTriton runtime from source with custom suffix to avoid
# potential conflict with libaotriton as provided by PyTorch
function(aotriton_build_from_source noimage)
message(STATUS "No-image mode: ${noimage}.")
ExternalProject_Add(aotriton_external
LIST_SEPARATOR ","
SOURCE_DIR ${TE}/3rdparty/aotriton
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DAOTRITON_TARGET_ARCH=${ARCH_LIST_COMMA_STR}
-DGPU_TARGETS=${ARCH_LIST_COMMA_STR}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-DAOTRITON_NOIMAGE_MODE=${noimage}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"
)
message(STATUS "Adding AOTriton library.")
add_dependencies(aotriton aotriton_external)
target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
install(DIRECTORY
${__AOTRITON_INSTALL_DIR}/lib
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
PATTERN "cmake" EXCLUDE
)
endfunction()

add_library(aotriton INTERFACE)
add_dependencies(aotriton aotriton_external)
target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
message(STATUS "Building AOTriton from source.")
string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE})

# Download GPU kernels if needed
if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
set(__AOTRITON_VER "0.10b")
set(__AOTRITON_SHA256 "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b")
string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/"
"${__AOTRITON_VER}/aotriton-"
"${__AOTRITON_VER}-manylinux_2_28"
"_x86_64-rocm7.0"
"-shared.tar.gz")
set(aotriton_image_dirs)
foreach(X IN LISTS CMAKE_HIP_ARCHITECTURES)
list(APPEND aotriton_image_dirs "${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball/lib/aotriton.images/amd-${X}")
message(STATUS "Downloading AOTriton GPU Kernels.")
foreach(image ${__AOTRITON_IMAGE_LIST})
string(SUBSTRING ${image} 7 -1 gfx_pattern)
string(REPLACE "x" "." gfx_regex ${gfx_pattern})
foreach(target ${ARCH_LIST_COMMA_STR})
if(target MATCHES ${gfx_regex})
message(STATUS "Downloading AOTriton image ${image}.")
set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern})
aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET})
add_dependencies(aotriton_external ${__AOTRITON_DOWNLOAD_TARGET})
break()
endif()
endforeach()
endforeach()
set(aotriton_lib_install_dir "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images")
file(REMOVE_RECURSE ${aotriton_lib_install_dir})
file(MAKE_DIRECTORY ${aotriton_lib_install_dir})
ExternalProject_Add(aotriton_images
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
BUILD_ALWAYS TRUE
INSTALL_COMMAND cp -Ra ${aotriton_image_dirs} ${aotriton_lib_install_dir})
add_dependencies(aotriton aotriton_images)
else()
endif()
install(DIRECTORY
${__AOTRITON_INSTALL_DIR}/lib
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
PATTERN "cmake" EXCLUDE
PATTERN "libaotriton${__AOTRITON_SUFFIX}_v2.so" EXCLUDE)

else()
# Use aotriton built during initial TE building/installation
# When only need rebuild TE library itself
Expand Down
6 changes: 5 additions & 1 deletion transformer_engine/common/fused_attn_rocm/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
} else if(fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_AOTriton){
fused_attn_aotriton_fwd_qkvpacked(
b, h, max_seqlen, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
is_training, attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_QKV,
output_O, Aux_CTX_Tensors,
input_cu_seqlens,
Expand Down Expand Up @@ -576,6 +578,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
fused_attn_aotriton_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV,
output_O, Aux_CTX_Tensors,
Expand Down Expand Up @@ -759,6 +762,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_aotriton_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk,
is_training, attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V,
output_O, Aux_CTX_Tensors,
Expand Down
Loading