2 changes: 1 addition & 1 deletion 3rdparty/aotriton
Submodule aotriton updated 146 files
140 changes: 96 additions & 44 deletions transformer_engine/common/CMakeLists.txt
@@ -278,67 +278,119 @@ else()
set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES})

if(USE_FUSED_ATTN_AOTRITON)
# This is for GPU kernel downloading
# The AOTriton C++ runtime will be built from ../../3rdparty/aotriton
# Hence there is no need to add multiple ROCm versions here

set(__AOTRITON_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton")
set(__AOTRITON_SUFFIX "_TEprivate")

if(NOT DEFINED AOTRITON_PATH)
# Install aotriton fused attn
# If AOTRITON_PATH is not provided, we proceed to build the runtime
# ourselves and either build or download the GPU kernels
if(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
set(AOTRITON_NOIMAGE_MODE OFF)
else()
set(AOTRITON_NOIMAGE_MODE ON)
endif()
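# Note on the flag above: AOTRITON_NOIMAGE_MODE=ON builds only the AOTriton C++
# runtime from the submodule, and the pre-compiled GPU kernel images are fetched
# separately by the download logic further below; with the flag OFF, the kernels
# are compiled from source as part of the runtime build.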

string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
set(__AOTRITON_VER "0.11b")
set(__AOTRITON_IMAGE_LIST
"amd-gfx942"
"amd-gfx950"
)
set(__AOTRITON_IMAGE_SHA256_LIST
"3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942
"27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950
Collaborator:
Why not add other archs?

Contributor Author:
I wasn't sure what our support matrix was for archs.
)
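On the reviewer's question about other archs: extending this path to another architecture amounts to appending a matching image name and checksum to the two lists above, assuming AOTriton publishes a pre-compiled image for it in this release. A minimal, hypothetical sketch (the gfx90a entry and its all-zero checksum are placeholders, not real release artifacts):

set(__AOTRITON_IMAGE_LIST
  "amd-gfx942"
  "amd-gfx950"
  "amd-gfx90a"  # hypothetical additional image
)
set(__AOTRITON_IMAGE_SHA256_LIST
  "3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942
  "27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950
  "0000000000000000000000000000000000000000000000000000000000000000" # placeholder: SHA256 of the gfx90a tarball
)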
set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
set(__AOTRITON_Z "gz")
include(ExternalProject)
ExternalProject_Add(aotriton_external
LIST_SEPARATOR ","
SOURCE_DIR ${TE}/3rdparty/aotriton
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DAOTRITON_TARGET_ARCH=${ARCH_LIST_COMMA_STR}
-DGPU_TARGETS=${ARCH_LIST_COMMA_STR}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"

# Download GPU kernels for a specific target
function(aotriton_download_image image project)
list(FIND __AOTRITON_IMAGE_LIST ${image} index)
list(GET __AOTRITON_IMAGE_SHA256_LIST ${index} __AOTRITON_IMAGE_SHA256)

string(CONCAT __AOTRITON_FILE
"aotriton-${__AOTRITON_VER}-images-"
"${image}.tar.${__AOTRITON_Z}")
string(CONCAT __AOTRITON_URL
"${__AOTRITON_BASE_URL}"
"${__AOTRITON_VER}/${__AOTRITON_FILE}")

# Set up directories
set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image})
set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image})

ExternalProject_Add(${project}
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_IMAGE_SHA256}
DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${__AOTRITON_EXTRACT_DIR}"
"${__AOTRITON_INSTALL_DIR}"
BUILD_BYPRODUCTS
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
)
message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.")
endfunction()
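# Illustrative usage (not an additional call introduced here): invoking
#   aotriton_download_image("amd-gfx942" aotriton_image_gfx942)
# registers an ExternalProject that downloads the gfx942 tarball, verifies its
# SHA256, extracts it, and copies the image tree into ${__AOTRITON_INSTALL_DIR};
# the per-architecture loop further below generates exactly these calls.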

# Build the AOTriton runtime from source with custom suffix to avoid
# potential conflict with libaotriton as provided by PyTorch
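# With __AOTRITON_SUFFIX set to "_TEprivate" above, the runtime is produced as
# libaotriton_TEprivate_v2.so, so it can coexist with the unsuffixed
# libaotriton_v2.so that a PyTorch ROCm installation may already ship.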
function(aotriton_build_from_source noimage)
message(STATUS "No-image mode: ${noimage}.")
ExternalProject_Add(aotriton_external
LIST_SEPARATOR ","
SOURCE_DIR ${TE}/3rdparty/aotriton
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DAOTRITON_TARGET_ARCH=${ARCH_LIST_COMMA_STR}
-DGPU_TARGETS=${ARCH_LIST_COMMA_STR}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-DAOTRITON_NOIMAGE_MODE=${noimage}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"
)
message(STATUS "Adding AOTriton library.")
add_dependencies(aotriton aotriton_external)
target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
install(DIRECTORY
${__AOTRITON_INSTALL_DIR}/lib
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
PATTERN "cmake" EXCLUDE
)
endfunction()

add_library(aotriton INTERFACE)
add_dependencies(aotriton aotriton_external)
target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
message(STATUS "Building AOTriton from source.")
string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE})

# Download GPU kernels if needed
if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
set(__AOTRITON_VER "0.10b")
set(__AOTRITON_SHA256 "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b")
string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/"
"${__AOTRITON_VER}/aotriton-"
"${__AOTRITON_VER}-manylinux_2_28"
"_x86_64-rocm7.0"
"-shared.tar.gz")
set(aotriton_image_dirs)
foreach(X IN LISTS CMAKE_HIP_ARCHITECTURES)
list(APPEND aotriton_image_dirs "${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball/lib/aotriton.images/amd-${X}")
message(STATUS "Downloading AOTriton GPU Kernels.")
set(__AOTRITON_CHAINED_IMAGE "aotriton_external")
foreach(image ${__AOTRITON_IMAGE_LIST})
string(SUBSTRING ${image} 7 -1 gfx_pattern)
string(REPLACE "x" "." gfx_regex ${gfx_pattern})
foreach(target ${ARCH_LIST_COMMA_STR})
if(target MATCHES ${gfx_regex})
message(STATUS "Downloading AOTriton image ${image}.")
set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern})
aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET})
add_dependencies(${__AOTRITON_CHAINED_IMAGE} ${__AOTRITON_DOWNLOAD_TARGET})
set(__AOTRITON_CHAINED_IMAGE ${__AOTRITON_DOWNLOAD_TARGET})
break()
endif()
endforeach()
endforeach()
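# How the matching above plays out on a concrete example (comment only):
#   "amd-gfx942" --SUBSTRING(7)--> "942" --REPLACE("x",".")--> "942",
#   which MATCHES the requested target gfx942 exactly; a hypothetical family
#   image such as "amd-gfx11xx" would yield the regex "11..", so one download
#   would cover gfx1100, gfx1101, and so on.
# The add_dependencies() chaining makes aotriton_external depend on the first
# downloaded image and each image on the next, serializing the copies into the
# shared ${__AOTRITON_INSTALL_DIR}.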
set(aotriton_lib_install_dir "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images")
file(REMOVE_RECURSE ${aotriton_lib_install_dir})
file(MAKE_DIRECTORY ${aotriton_lib_install_dir})
ExternalProject_Add(aotriton_images
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
BUILD_ALWAYS TRUE
INSTALL_COMMAND cp -Ra ${aotriton_image_dirs} ${aotriton_lib_install_dir})
add_dependencies(aotriton aotriton_images)
else()
endif()
install(DIRECTORY
${__AOTRITON_INSTALL_DIR}/lib
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
PATTERN "cmake" EXCLUDE
PATTERN "libaotriton${__AOTRITON_SUFFIX}_v2.so" EXCLUDE)

else()
# Use the aotriton built during the initial TE build/installation
# when only the TE library itself needs to be rebuilt
6 changes: 5 additions & 1 deletion transformer_engine/common/fused_attn_rocm/fused_attn.cpp
@@ -401,7 +401,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
} else if(fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_AOTriton){
fused_attn_aotriton_fwd_qkvpacked(
b, h, max_seqlen, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
is_training, attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_QKV,
output_O, Aux_CTX_Tensors,
input_cu_seqlens,
@@ -576,6 +578,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
fused_attn_aotriton_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV,
output_O, Aux_CTX_Tensors,
@@ -759,6 +762,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_aotriton_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk,
is_training, attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V,
output_O, Aux_CTX_Tensors,