Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/aotriton
Submodule aotriton updated 146 files
145 changes: 99 additions & 46 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -278,67 +278,120 @@ else()
set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES})

if(USE_FUSED_ATTN_AOTRITON)
# This is for GPU kernel downloading
# The AOTriton C++ runtime will be built from ../../3rdparty/aotriton
# Hence there is no need to add multiple ROCM version here

set(__AOTRITON_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton")
set(__AOTRITON_SUFFIX "_TEprivate")

if(NOT DEFINED AOTRITON_PATH)
# Install aotriton fused attn
if(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
set(AOTRITON_NOIMAGE_MODE OFF)
else()
set(AOTRITON_NOIMAGE_MODE ON)
endif()

string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
set(__AOTRITON_VER "0.11b")
set(__AOTRITON_SHA256
"a2a974e0ad929a5e5827c0f896c59bda4872459cbaf8dd8e0a00407f404491cf" # rocm7.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be removed.
TE should never download the runtime and must build runtime from source with custom suffix, due to potential conflict with libaotriton shipped by pytorch.

)
set(__AOTRITON_IMAGE_LIST
"amd-gfx942"
"amd-gfx950"
)
set(__AOTRITON_IMAGE_SHA256_LIST
"3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942
"27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do not add other archs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure what our support matrix was for archs

)
set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
set(__AOTRITON_Z "gz")
# Set the default __AOTRITON_LIB path
set(__AOTRITON_LIB "lib/libaotriton_v2.so")
include(ExternalProject)
ExternalProject_Add(aotriton_external
LIST_SEPARATOR ","
SOURCE_DIR ${TE}/3rdparty/aotriton
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DAOTRITON_TARGET_ARCH=${ARCH_LIST_COMMA_STR}
-DGPU_TARGETS=${ARCH_LIST_COMMA_STR}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"

function(aotriton_download_image image project)
list(FIND __AOTRITON_IMAGE_LIST ${image} index)
list(GET __AOTRITON_IMAGE_SHA256_LIST ${index} __AOTRITON_IMAGE_SHA256)

string(CONCAT __AOTRITON_FILE
"aotriton-${__AOTRITON_VER}-images-"
"${image}.tar.${__AOTRITON_Z}")
string(CONCAT __AOTRITON_URL
"${__AOTRITON_BASE_URL}"
"${__AOTRITON_VER}/${__AOTRITON_FILE}")

# Set up directories
set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image})
set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image})

ExternalProject_Add(${project}
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_IMAGE_SHA256}
DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${__AOTRITON_EXTRACT_DIR}"
"${__AOTRITON_INSTALL_DIR}"
BUILD_BYPRODUCTS
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
)
message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.")
endfunction()

function(aotriton_build_from_source noimage)
if(noimage)
SET(RECURSIVE "OFF")
else()
SET(RECURSIVE "ON")
endif()
message(STATUS "No-image mode: ${noimage}.")
ExternalProject_Add(aotriton_external
LIST_SEPARATOR ","
SOURCE_DIR ${TE}/3rdparty/aotriton
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DAOTRITON_TARGET_ARCH=${ARCH_LIST_COMMA_STR}
-DGPU_TARGETS=${ARCH_LIST_COMMA_STR}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"
)
message(STATUS "Adding AOTriton library.")
add_dependencies(aotriton aotriton_external)
target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
install(DIRECTORY
${__AOTRITON_INSTALL_DIR}/lib
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
PATTERN "cmake" EXCLUDE
)
endfunction()
string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
add_library(aotriton INTERFACE)
add_dependencies(aotriton aotriton_external)
target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
message(STATUS "Building AOTriton from source.")
aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE})
if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
set(__AOTRITON_VER "0.10b")
set(__AOTRITON_SHA256 "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b")
string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/"
"${__AOTRITON_VER}/aotriton-"
"${__AOTRITON_VER}-manylinux_2_28"
"_x86_64-rocm7.0"
"-shared.tar.gz")
set(aotriton_image_dirs)
foreach(X IN LISTS CMAKE_HIP_ARCHITECTURES)
list(APPEND aotriton_image_dirs "${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball/lib/aotriton.images/amd-${X}")
message(STATUS "Downloading AOTriton GPU Kernels.")
set(__AOTRITON_CHAINED_IMAGE "aotriton_external")
foreach(image ${__AOTRITON_IMAGE_LIST})
string(SUBSTRING ${image} 7 -1 gfx_pattern)
string(REPLACE "x" "." gfx_regex ${gfx_pattern})
foreach(target ${ARCH_LIST_COMMA_STR})
if(target MATCHES ${gfx_regex})
message(STATUS "Downloading AOTriton image ${image}.")
set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern})
aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET})
add_dependencies(${__AOTRITON_CHAINED_IMAGE} ${__AOTRITON_DOWNLOAD_TARGET})
set(__AOTRITON_CHAINED_IMAGE ${__AOTRITON_DOWNLOAD_TARGET})
break()
endif()
endforeach()
endforeach()
set(aotriton_lib_install_dir "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images")
file(REMOVE_RECURSE ${aotriton_lib_install_dir})
file(MAKE_DIRECTORY ${aotriton_lib_install_dir})
ExternalProject_Add(aotriton_images
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
BUILD_ALWAYS TRUE
INSTALL_COMMAND cp -Ra ${aotriton_image_dirs} ${aotriton_lib_install_dir})
add_dependencies(aotriton aotriton_images)
else()
endif()
install(DIRECTORY
${__AOTRITON_INSTALL_DIR}/lib
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
PATTERN "cmake" EXCLUDE
PATTERN "libaotriton${__AOTRITON_SUFFIX}_v2.so" EXCLUDE)


else()
# Use aotriton built during initial TE building/installation
# When only need rebuild TE library itself
Expand Down
9 changes: 8 additions & 1 deletion transformer_engine/common/fused_attn_rocm/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
} else if(fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_AOTriton){
fused_attn_aotriton_fwd_qkvpacked(
b, h, max_seqlen, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
is_training, attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_QKV,
output_O, Aux_CTX_Tensors,
input_cu_seqlens,
Expand Down Expand Up @@ -490,6 +492,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
fused_attn_aotriton_bwd_qkvpacked(
b, h, max_seqlen, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO, output_S,
output_dQKV,
Expand Down Expand Up @@ -576,6 +579,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
fused_attn_aotriton_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV,
output_O, Aux_CTX_Tensors,
Expand Down Expand Up @@ -675,6 +679,7 @@ void nvte_fused_attn_bwd_kvpacked(
fused_attn_aotriton_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_O, input_dO,
output_S,
Expand Down Expand Up @@ -759,6 +764,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_aotriton_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk,
is_training, attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V,
output_O, Aux_CTX_Tensors,
Expand Down Expand Up @@ -854,6 +860,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_aotriton_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_O, input_dO,
output_S,
Expand Down
Loading