From d10fa92a794293c743d7e3f7991a9ff305f7f6a7 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 24 Oct 2025 16:22:47 -0500 Subject: [PATCH 01/21] Initial commit --- 3rdparty/aotriton | 2 +- transformer_engine/common/CMakeLists.txt | 154 ++++++++++++++++------- 2 files changed, 113 insertions(+), 43 deletions(-) diff --git a/3rdparty/aotriton b/3rdparty/aotriton index 6fca155f4..972223c50 160000 --- a/3rdparty/aotriton +++ b/3rdparty/aotriton @@ -1 +1 @@ -Subproject commit 6fca155f4deeb8d9529326f7b69f350aeeb93477 +Subproject commit 972223c501ffc22068bb035ac5d64cf54318d895 diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index f70c9f8bb..8313cee50 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -278,67 +278,137 @@ else() set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES}) if(USE_FUSED_ATTN_AOTRITON) - # This is for GPU kernel downloading - # The AOTriton C++ runtime will be built from ../../3rdparty/aotriton - # Hence there is no need to add multiple ROCM version here + set(__AOTRITON_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton") set(__AOTRITON_SUFFIX "_TEprivate") + if(NOT DEFINED AOTRITON_PATH) - # Install aotriton fused attn if(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) set(AOTRITON_NOIMAGE_MODE OFF) else() set(AOTRITON_NOIMAGE_MODE ON) endif() - string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}") + set(__AOTRITON_VER "0.11b") + set(__AOTRITON_SHA256 + "a2a974e0ad929a5e5827c0f896c59bda4872459cbaf8dd8e0a00407f404491cf" # rocm7.0 + ) + set(__AOTRITON_IMAGE_LIST + "amd-gfx942" + "amd-gfx950" + ) + set(__AOTRITON_IMAGE_SHA256_LIST + "3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942 + "27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950 + ) + set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore + set(__AOTRITON_Z "gz") + # Set the default __AOTRITON_LIB path + set(__AOTRITON_LIB "lib/libaotriton_v2.so") include(ExternalProject) - ExternalProject_Add(aotriton_external - LIST_SEPARATOR "," - SOURCE_DIR ${TE}/3rdparty/aotriton - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} - -DAOTRITON_TARGET_ARCH=${ARCH_LIST_COMMA_STR} - -DGPU_TARGETS=${ARCH_LIST_COMMA_STR} - -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} - -DAOTRITON_NO_PYTHON=ON - -DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX} - -DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE} - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so" + + function(aotriton_download_image image project) + list(FIND __AOTRITON_IMAGE_LIST ${image} index) + list(GET __AOTRITON_IMAGE_SHA256_LIST ${index} __AOTRITON_IMAGE_SHA256) + + string(CONCAT __AOTRITON_FILE + "aotriton-${__AOTRITON_VER}-images-" + "${image}.tar.${__AOTRITON_Z}") + string(CONCAT __AOTRITON_URL + "${__AOTRITON_BASE_URL}" + "${__AOTRITON_VER}/${__AOTRITON_FILE}") + + # Set up directories + set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image}) + set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}) + + ExternalProject_Add(${project} + URL "${__AOTRITON_URL}" + URL_HASH SHA256=${__AOTRITON_IMAGE_SHA256} + DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR} + SOURCE_DIR ${__AOTRITON_EXTRACT_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory + "${__AOTRITON_EXTRACT_DIR}" + "${__AOTRITON_INSTALL_DIR}" + BUILD_BYPRODUCTS + "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__" ) - add_library(aotriton INTERFACE) - add_dependencies(aotriton aotriton_external) - target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so) - target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) - if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) - set(__AOTRITON_VER "0.10b") - set(__AOTRITON_SHA256 "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b") - string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/" + message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.") + endfunction() + + function(aotriton_download_runtime) + message(STATUS "Preparing to download AOTriton runtime.") + string(CONCAT __AOTRITON_URL "${__AOTRITON_BASE_URL}" "${__AOTRITON_VER}/aotriton-" "${__AOTRITON_VER}-manylinux_2_28" "_x86_64-rocm7.0" "-shared.tar.gz") - set(aotriton_image_dirs) - foreach(X IN LISTS CMAKE_HIP_ARCHITECTURES) - list(APPEND aotriton_image_dirs "${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball/lib/aotriton.images/amd-${X}") - endforeach() - set(aotriton_lib_install_dir "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images") - file(REMOVE_RECURSE ${aotriton_lib_install_dir}) - file(MAKE_DIRECTORY ${aotriton_lib_install_dir}) + message(STATUS "Downloading AOTriton runtime from ${__AOTRITON_URL}.") ExternalProject_Add(aotriton_images - URL "${__AOTRITON_URL}" + URL ${__AOTRITON_URL} URL_HASH SHA256=${__AOTRITON_SHA256} - SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - BUILD_ALWAYS TRUE - INSTALL_COMMAND cp -Ra ${aotriton_image_dirs} ${aotriton_lib_install_dir}) + SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_images + INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory + "${CMAKE_CURRENT_BINARY_DIR}/aotriton_images" + "${__AOTRITON_INSTALL_DIR}" + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}" + ) add_dependencies(aotriton aotriton_images) + message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\ + Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.") + endfunction() + + function(aotriton_build_from_source noimage) + if(noimage) + SET(RECURSIVE "OFF") + else() + SET(RECURSIVE "ON") + endif() + message(STATUS "No-image mode: ${noimage}.") + string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}") + ExternalProject_Add(aotriton_external + LIST_SEPARATOR "," + SOURCE_DIR ${TE}/3rdparty/aotriton + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} + -DAOTRITON_TARGET_ARCH=${ARCH_LIST_COMMA_STR} + -DGPU_TARGETS=${ARCH_LIST_COMMA_STR} + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DAOTRITON_NO_PYTHON=ON + -DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX} + -DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE} + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so" + ) + message(STATUS "Adding AOTriton library.") + add_library(aotriton INTERFACE) + add_dependencies(aotriton aotriton_external) + target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so) + target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) + target_include_directories(aotriton INTERFACE ${TE}/3rdparty/aotriton/include) + endfunction() + if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) + # aotriton_download_runtime() + set(__AOTRITON_CHAINED_IMAGE "aotriton_external") + foreach(image ${__AOTRITON_IMAGE_LIST}) + string(SUBSTRING ${image} 7 -1 gfx_pattern) + string(REPLACE "x" "." gfx_regex ${gfx_pattern}) + foreach(target ${ARCH_LIST_COMMA_STR}) + if(target MATCHES ${gfx_regex}) + set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern}) + aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET}) + add_dependencies(${__AOTRITON_CHAINED_IMAGE} ${__AOTRITON_DOWNLOAD_TARGET}) + set(__AOTRITON_CHAINED_IMAGE ${__AOTRITON_DOWNLOAD_TARGET}) + break() + endif() + endforeach() + endforeach() + else() + message(STATUS "Building AOTriton from source.") + aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE}) endif() - install(DIRECTORY - ${__AOTRITON_INSTALL_DIR}/lib - DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine - PATTERN "cmake" EXCLUDE - PATTERN "libaotriton${__AOTRITON_SUFFIX}_v2.so" EXCLUDE) + + else() # Use aotriton built during initial TE building/installation # When only need rebuild TE library itself From eef7dc0e250cec622be319a46d5f7415300fe050 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 24 Oct 2025 16:40:41 -0500 Subject: [PATCH 02/21] Updated to build from source by default --- transformer_engine/common/CMakeLists.txt | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8313cee50..8185eee47 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -381,12 +381,18 @@ else() BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so" ) message(STATUS "Adding AOTriton library.") - add_library(aotriton INTERFACE) add_dependencies(aotriton aotriton_external) target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so) target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) - target_include_directories(aotriton INTERFACE ${TE}/3rdparty/aotriton/include) + install(DIRECTORY + ${__AOTRITON_INSTALL_DIR}/lib + DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine + PATTERN "cmake" EXCLUDE + ) endfunction() + add_library(aotriton INTERFACE) + message(STATUS "Building AOTriton from source.") + aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE}) if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) # aotriton_download_runtime() set(__AOTRITON_CHAINED_IMAGE "aotriton_external") @@ -404,8 +410,6 @@ else() endforeach() endforeach() else() - message(STATUS "Building AOTriton from source.") - aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE}) endif() From cc68ab73877e2d9d60325baa00eb54b1aa60e2ff Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 31 Oct 2025 13:21:42 -0500 Subject: [PATCH 03/21] Updated for V3 API --- transformer_engine/common/CMakeLists.txt | 23 -- .../common/fused_attn_rocm/fused_attn.cpp | 9 +- .../fused_attn_rocm/fused_attn_aotriton.cpp | 387 ++++++++++++++---- .../fused_attn_rocm/fused_attn_aotriton.h | 6 + 4 files changed, 330 insertions(+), 95 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8185eee47..828ec51ff 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -338,28 +338,6 @@ else() message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.") endfunction() - function(aotriton_download_runtime) - message(STATUS "Preparing to download AOTriton runtime.") - string(CONCAT __AOTRITON_URL "${__AOTRITON_BASE_URL}" - "${__AOTRITON_VER}/aotriton-" - "${__AOTRITON_VER}-manylinux_2_28" - "_x86_64-rocm7.0" - "-shared.tar.gz") - message(STATUS "Downloading AOTriton runtime from ${__AOTRITON_URL}.") - ExternalProject_Add(aotriton_images - URL ${__AOTRITON_URL} - URL_HASH SHA256=${__AOTRITON_SHA256} - SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_images - INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory - "${CMAKE_CURRENT_BINARY_DIR}/aotriton_images" - "${__AOTRITON_INSTALL_DIR}" - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}" - ) - add_dependencies(aotriton aotriton_images) - message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\ - Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.") - endfunction() - function(aotriton_build_from_source noimage) if(noimage) SET(RECURSIVE "OFF") @@ -394,7 +372,6 @@ else() message(STATUS "Building AOTriton from source.") aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE}) if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) - # aotriton_download_runtime() set(__AOTRITON_CHAINED_IMAGE "aotriton_external") foreach(image ${__AOTRITON_IMAGE_LIST}) string(SUBSTRING ${image} 7 -1 gfx_pattern) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 66fa72c0c..01b5989da 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -401,7 +401,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if(fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_AOTriton){ fused_attn_aotriton_fwd_qkvpacked( b, h, max_seqlen, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + is_training, attn_scale, dropout, + window_size_left, window_size_right, + qkv_layout, bias_type, attn_mask_type, input_QKV, output_O, Aux_CTX_Tensors, input_cu_seqlens, @@ -490,6 +492,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con fused_attn_aotriton_bwd_qkvpacked( b, h, max_seqlen, d, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_QKV, input_O, input_dO, output_S, output_dQKV, @@ -576,6 +579,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const fused_attn_aotriton_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, output_O, Aux_CTX_Tensors, @@ -675,6 +679,7 @@ void nvte_fused_attn_bwd_kvpacked( fused_attn_aotriton_bwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, input_dO, output_S, @@ -759,6 +764,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_aotriton_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, output_O, Aux_CTX_Tensors, @@ -854,6 +860,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_aotriton_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, output_S, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 3fe7ec854..055548836 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -37,6 +37,39 @@ inline aotriton::TensorView<0> mk_aoscalartensor(const uint64_t* ptr) namespace transformer_engine { namespace fused_attn_rocm { +bool get_pad_between_seqs( + const Tensor* input_cu_seqlens, + const Tensor* input_cu_seqlens_padded, + NVTE_QKV_Format qkv_format, NVTE_Mask_Type attn_mask_type +){ + bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD; + bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + // First we check whether we have a ragged array with a non-trivial + // input_cu_seqlens_padded tensor + bool pad_between_seqs = ( + is_ragged + && input_cu_seqlens->data.dptr!=input_cu_seqlens_padded->data.dptr + && !input_cu_seqlens_padded->data.shape.empty() + ); + // Next we guard against an initial workspace-allocation which occurs in the + // JAX TE extension. We check for both pointers being null while retaining + // shape data, indicating the use of dummy data in the allocation pass. + pad_between_seqs = pad_between_seqs || ( + is_ragged + && input_cu_seqlens->data.dptr==nullptr && !input_cu_seqlens->data.shape.empty() + && input_cu_seqlens_padded->data.dptr==nullptr && !input_cu_seqlens_padded->data.shape.empty() + ); + // Finally we check whether we have an array with padding and non-empty input_cu_seqlens + pad_between_seqs = pad_between_seqs || ( + !is_ragged + && is_padding + && !input_cu_seqlens->data.shape.empty() + ); + return pad_between_seqs; +} + // check the fused attn config to see whether it's aotriton backend supported bool is_aotriton_backend_supported( NVTEDType q_dtype, @@ -127,12 +160,12 @@ aotriton::DType nvte_to_aotriton_dtype(DType t_dtype){ void fused_attn_aotriton_fwd_impl( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d, bool is_training, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int window_size_left, int window_size_right, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool pad_between_seqs, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrSoftmaxAux, void *devPtrO, const uint64_t* devPtrDropoutSeed, const uint64_t* devPtrDropoutOffset, - //void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, + void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, aotriton::DType dtype, void *workspace, size_t *workspace_size, @@ -163,6 +196,11 @@ void fused_attn_aotriton_fwd_impl( auto k_tensor = aotriton::TensorView<4>(reinterpret_cast(devPtrK), kv_shape, k_stride, dtype); auto v_tensor = aotriton::TensorView<4>(reinterpret_cast(devPtrV), kv_shape, v_stride, dtype); + // Cumulative seqlen tensors + std::array cu_seqlens_shape{b+1}; + std::array cu_seqlens_stride{1}; + auto cu_seqlens_q = aotriton::TensorView<1>(reinterpret_cast(devPtrCuSeqlensQ), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32); + auto cu_seqlens_k = aotriton::TensorView<1>(reinterpret_cast(devPtrCuSeqlensKV), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32); std::array o_stride; generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), @@ -185,6 +223,35 @@ void fused_attn_aotriton_fwd_impl( if (env_p != nullptr && std::string(env_p) == "1") nvte_log_aotriton_config = true; } + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype); + using aotriton::v3::flash::attn_fwd; + auto seed = mk_aoscalartensor(devPtrDropoutSeed); + auto offset1 = mk_aoscalartensor(devPtrDropoutOffset); + auto seed_output = mk_aoscalartensor(nullptr); + auto offset_output = mk_aoscalartensor(nullptr); + const auto is_causal = mask_type == NVTE_CAUSAL_MASK; + aotriton::TensorView<0> atomic_for_causal(reinterpret_cast(workspace), aotriton::DType::kInt32); + int8_t varlen_type = 0; + auto qkv_format = nvte_get_qkv_format(layout); + if(pad_between_seqs){ + varlen_type = 2; + }else if(qkv_format == NVTE_QKV_Format::NVTE_THD){ + varlen_type = 1; + } + + int window_left = 0; + int window_right = 0; + using aotriton::v3::flash::WindowValue; + if (is_causal) { + window_left = WindowValue::BottomRightAligned; + window_right = WindowValue::BottomRightAligned; + } + if (window_size_left>0 || window_size_right>0) { + window_left = (window_size_left>0)? window_size_left:window_left; + window_right = (window_size_right>0)? window_size_right:window_right; + } + using aotriton::v3::flash::CausalType; + int8_t causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; if (nvte_log_aotriton_config) { std::cout< empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype); - using aotriton::v2::flash::attn_fwd; - auto seed = mk_aoscalartensor(devPtrDropoutSeed); - auto offset1 = mk_aoscalartensor(devPtrDropoutOffset); - auto offset2 = 0; - auto seed_output = mk_aoscalartensor(nullptr); - auto offset_output = mk_aoscalartensor(nullptr); - const auto is_causal = mask_type == NVTE_CAUSAL_MASK; - aotriton::TensorView<0> atomic_for_causal(reinterpret_cast(workspace), aotriton::DType::kInt32); + + aotriton::v3::flash::attn_fwd_params fwd_params{}; + fwd_params.Q = q_tensor; + fwd_params.K = k_tensor; + fwd_params.V = v_tensor; + // fwd_params.B = empty_bias; + // fwd_params.A = nullptr; // Alibi slopes, currently unused + fwd_params.Sm_scale = scaling_factor; + fwd_params.L = M_tensor; + fwd_params.Out = o_tensor; + if(varlen_type){ + fwd_params.cu_seqlens_q = cu_seqlens_q; + fwd_params.cu_seqlens_k = cu_seqlens_k; + fwd_params.Max_seqlen_q = s_q; // Unused if cu_seqlens_q is empty + fwd_params.Max_seqlen_k = s_kv; // Unused if cu_seqlens_k is empty + } + fwd_params.dropout_p = is_training? dropout_probability : 0; + fwd_params.philox_seed_ptr = seed; + fwd_params.philox_offset1 = offset1; + fwd_params.philox_offset2 = 0; + fwd_params.philox_seed_output = seed_output; + fwd_params.philox_offset_output = offset_output; + fwd_params.encoded_softmax = encoded_softmax_tensor; + fwd_params.persistent_atomic_counter = atomic_for_causal; + fwd_params.causal_type = causal_type; + fwd_params.varlen_type = varlen_type; + fwd_params.window_left = window_left; + fwd_params.window_right = window_right; + NVTE_CHECK_CUDA(hipMemsetAsync(workspace, 0, sizeof(int32_t), stream)); - NVTE_CHECK_CUDA(attn_fwd(q_tensor, - k_tensor, - v_tensor, - empty_bias, - scaling_factor, - M_tensor, - o_tensor, - is_training? dropout_probability : 0, - seed, - offset1, - offset2, - seed_output, - offset_output, - encoded_softmax_tensor, - is_causal, - atomic_for_causal, - stream)); + NVTE_CHECK_CUDA(attn_fwd(fwd_params, fwd_params.kVersion, stream)); } +// A thin conversion wrapper around eager tensor-views to lazy tensors +template +struct LazyTensorContext { + aotriton::TensorView tensor_view; +}; +template +struct LazyTensorFunctions { + static aotriton::TensorView acquire(void* cookie) { + return static_cast*>(cookie)->tensor_view; + } + static void dispose(void* cookie) { + } +}; + void fused_attn_aotriton_bwd_impl( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + uint64_t window_size_left, uint64_t window_size_right, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool pad_between_seqs, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrO, void* devPtrSoftmaxAux, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, + void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, const uint64_t* devPtrDropoutSeed, const uint64_t* devPtrDropoutOffset, aotriton::DType dtype, @@ -246,12 +356,20 @@ void fused_attn_aotriton_bwd_impl( size_t *workspace_size, cudaStream_t stream) { + const uint64_t dq_acc_size = b*s_q*h*d*sizeof(float); + // Exit to request upper level API to allocate memory if needed if(workspace==nullptr){ - // CK only requires workspace for lse softmax + // AOTriton requires workspace for lse softmax *workspace_size = b*h*s_q*sizeof(float); + // AOTriton requires workspace for DQ_ACC + *workspace_size += dq_acc_size; return; } + void * softmax_lse_ptr = workspace; + workspace = static_cast(static_cast(workspace) + b*h*s_q*sizeof(float)); + void * dq_acc_ptr = workspace; + std::array q_stride; std::array k_stride; std::array v_stride; @@ -271,7 +389,7 @@ void fused_attn_aotriton_bwd_impl( std::array q_shape{b, h, s_q, d}; std::array kv_shape{b, hg, s_kv, d}; - // m and workspace are of the same shape and stride + // m and softmax_lse are of the same shape and stride std::array m_shape{b * h, s_q}; std::array m_stride{s_q, 1}; @@ -289,13 +407,57 @@ void fused_attn_aotriton_bwd_impl( // auxilary tensors auto M_tensor = aotriton::TensorView<2>(reinterpret_cast(devPtrSoftmaxAux), m_shape, m_stride, aotriton::DType::kFloat32); - auto wkspace_tensor = aotriton::TensorView<2>(reinterpret_cast(workspace), m_shape, m_stride, aotriton::DType::kFloat32); + auto softmax_lse_tensor = aotriton::TensorView<2>(reinterpret_cast(softmax_lse_ptr), m_shape, m_stride, aotriton::DType::kFloat32); + auto dq_acc_tensor = aotriton::TensorView<4>(reinterpret_cast(dq_acc_ptr), q_shape, q_stride, aotriton::DType::kFloat32); + NVTE_CHECK_CUDA(hipMemsetAsync(dq_acc_ptr, 0, dq_acc_size, stream)); + + LazyTensorContext<4> dq_acc_ctx {.tensor_view = dq_acc_tensor}; + using LTF = LazyTensorFunctions<4>; + auto dq_acc_lazy = aotriton::LazyTensor<4> { + .cookie = &dq_acc_ctx, + .acquire = <F::acquire, + .dispose = <F::dispose + }; + + // Cumulative seqlen tensors + std::array cu_seqlens_shape{b+1}; + std::array cu_seqlens_stride{1}; + auto cu_seqlens_q = aotriton::TensorView<1>(reinterpret_cast(devPtrCuSeqlensQ), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32); + auto cu_seqlens_k = aotriton::TensorView<1>(reinterpret_cast(devPtrCuSeqlensKV), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32); bool nvte_log_aotriton_config = false; if (const char* env_p = std::getenv("NVTE_LOG_AOTRITON_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") nvte_log_aotriton_config = true; } + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype); + using aotriton::v2::flash::attn_bwd; + auto seed = mk_aoscalartensor(devPtrDropoutSeed); + auto offset = mk_aoscalartensor(devPtrDropoutOffset); + const auto is_causal = mask_type == NVTE_CAUSAL_MASK; + int8_t varlen_type = 0; + auto qkv_format = nvte_get_qkv_format(layout); + if(pad_between_seqs){ + varlen_type = 2; + }else if(qkv_format == NVTE_QKV_Format::NVTE_THD){ + varlen_type = 1; + } + int window_left = s_q; + int window_right = s_kv; + bool needs_swa = false; + using aotriton::v3::flash::WindowValue; + if (is_causal) { + window_left = WindowValue::BottomRightAligned; + window_right = WindowValue::BottomRightAligned; + } + if (window_size_left>0 || window_size_right>0) { + needs_swa = true; + window_left = (window_size_left>0)? window_size_left:window_left; + window_right = (window_size_right>0)? window_size_right:window_right; + } + using aotriton::v3::flash::CausalType; + int8_t causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; + if (nvte_log_aotriton_config) { std::cout<*>(dq_acc_lazy.cookie)->tensor_view.data_ptr()<<"\n"; } - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype); - using aotriton::v2::flash::attn_bwd; - auto seed = mk_aoscalartensor(devPtrDropoutSeed); - auto offset = mk_aoscalartensor(devPtrDropoutOffset); - const auto is_causal = mask_type == NVTE_CAUSAL_MASK; - NVTE_CHECK_CUDA(attn_bwd(q_tensor, - k_tensor, - v_tensor, - empty_bias, - scaling_factor, - o_tensor, - do_tensor, - dq_tensor, - dk_tensor, - dv_tensor, - empty_bias, - M_tensor, - wkspace_tensor, - dropout_probability, - seed, - offset, - 0, - is_causal, - stream)); + aotriton::v3::flash::attn_bwd_params bwd_params{}; + bwd_params.Q = q_tensor; + bwd_params.K = k_tensor; + bwd_params.V = v_tensor; + bwd_params.B = empty_bias; + bwd_params.Sm_scale = scaling_factor; + bwd_params.Out = o_tensor; + bwd_params.cu_seqlens_q = cu_seqlens_q; + bwd_params.cu_seqlens_k = cu_seqlens_k; + bwd_params.Max_seqlen_q = s_q; + bwd_params.Max_seqlen_k = s_kv; + bwd_params.DO = o_tensor; + bwd_params.DK = do_tensor; + bwd_params.DV = dq_tensor; + bwd_params.DQ = dv_tensor; + bwd_params.DB = empty_bias; + bwd_params.L = M_tensor; + // bwd_params.D = softmax_lse_tensor; // ??? + bwd_params.dropout_p = dropout_probability; + bwd_params.philox_seed_ptr = seed; + bwd_params.philox_offset1 = offset; + bwd_params.philox_offset2 = 0; + bwd_params.causal_type = is_causal; + bwd_params.varlen_type = varlen_type; + bwd_params.window_left = window_left; + bwd_params.window_right = window_right; + bwd_params.DQ_ACC = dq_acc_lazy; + + NVTE_CHECK_CUDA(attn_bwd(bwd_params, bwd_params.kVersion, stream)); } #endif // USE_FUSED_ATTN_AOTRITON } // namespace fused_attn_rocm @@ -343,6 +534,7 @@ using namespace transformer_engine::fused_attn_rocm; void fused_attn_aotriton_fwd_qkvpacked( size_t b, size_t h, size_t max_seqlen, size_t d, bool is_training, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_QKV, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, @@ -390,16 +582,24 @@ void fused_attn_aotriton_fwd_qkvpacked( } size_t workspace_size = 0; + bool pad_between_seqs = get_pad_between_seqs( + input_cu_seqlens, + input_cu_seqlens, + nvte_get_qkv_format(qkv_layout), + attn_mask_type + ); fused_attn_aotriton_fwd_impl( b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + window_left, window_right, qkv_layout, - bias_type, attn_mask_type, + bias_type, attn_mask_type, pad_between_seqs, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, + input_cu_seqlens->data.dptr, input_cu_seqlens->data.dptr, nvte_to_aotriton_dtype(QKV_type), workspace->data.dptr, &workspace_size, @@ -426,6 +626,7 @@ void fused_attn_aotriton_fwd_qkvpacked( void fused_attn_aotriton_bwd_qkvpacked( size_t b, size_t h, size_t max_seqlen, size_t d, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -460,16 +661,24 @@ void fused_attn_aotriton_bwd_qkvpacked( void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); size_t workspace_size = 0; + bool pad_between_seqs = get_pad_between_seqs( + input_cu_seqlens, + input_cu_seqlens, + nvte_get_qkv_format(qkv_layout), + attn_mask_type + ); fused_attn_aotriton_bwd_impl( b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, + window_left, window_right, qkv_layout, - bias_type, attn_mask_type, + bias_type, attn_mask_type, pad_between_seqs, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + input_cu_seqlens->data.dptr, input_cu_seqlens->data.dptr, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, nvte_to_aotriton_dtype(QKV_type), @@ -498,6 +707,7 @@ void fused_attn_aotriton_bwd_qkvpacked( void fused_attn_aotriton_fwd_kvpacked( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, bool is_training, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_KV, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, @@ -545,16 +755,24 @@ void fused_attn_aotriton_fwd_kvpacked( } size_t workspace_size = 0; + bool pad_between_seqs = get_pad_between_seqs( + input_cu_seqlens_q, + input_cu_seqlens_kv, + nvte_get_qkv_format(qkv_layout), + attn_mask_type + ); fused_attn_aotriton_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, + window_left, window_right, qkv_layout, - bias_type, attn_mask_type, + bias_type, attn_mask_type, pad_between_seqs, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, + input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, nvte_to_aotriton_dtype(QKV_type), workspace->data.dptr, &workspace_size, @@ -581,6 +799,7 @@ void fused_attn_aotriton_fwd_kvpacked( void fused_attn_aotriton_bwd_kvpacked( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -618,16 +837,24 @@ void fused_attn_aotriton_bwd_kvpacked( void *devPtrSoftmaxStats = output_S->data.dptr; size_t workspace_size = 0; + bool pad_between_seqs = get_pad_between_seqs( + input_cu_seqlens_q, + input_cu_seqlens_kv, + nvte_get_qkv_format(qkv_layout), + attn_mask_type + ); fused_attn_aotriton_bwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + window_left, window_right, qkv_layout, - bias_type, attn_mask_type, + bias_type, attn_mask_type, pad_between_seqs, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, nvte_to_aotriton_dtype(QKV_type), @@ -656,6 +883,7 @@ void fused_attn_aotriton_bwd_kvpacked( void fused_attn_aotriton_fwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, bool is_training, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, @@ -694,16 +922,24 @@ void fused_attn_aotriton_fwd( } size_t workspace_size = 0; + bool pad_between_seqs = get_pad_between_seqs( + input_cu_seqlens_q, + input_cu_seqlens_kv, + nvte_get_qkv_format(qkv_layout), + attn_mask_type + ); fused_attn_aotriton_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, + window_left, window_right, qkv_layout, - bias_type, attn_mask_type, - devPtrQ, devPtrK, devPtrV, + bias_type, attn_mask_type, pad_between_seqs, + devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, + input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, nvte_to_aotriton_dtype(QKV_type), workspace->data.dptr, &workspace_size, @@ -730,6 +966,7 @@ void fused_attn_aotriton_fwd( void fused_attn_aotriton_bwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -755,16 +992,24 @@ void fused_attn_aotriton_bwd( void *devPtrSoftmaxStats = output_S->data.dptr; size_t workspace_size = 0; + bool pad_between_seqs = get_pad_between_seqs( + input_cu_seqlens_q, + input_cu_seqlens_kv, + nvte_get_qkv_format(qkv_layout), + attn_mask_type + ); fused_attn_aotriton_bwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + window_left, window_right, qkv_layout, - bias_type, attn_mask_type, + bias_type, attn_mask_type, pad_between_seqs, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, nvte_to_aotriton_dtype(QKV_type), diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h index 5b9ef89a1..3730cdad6 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h @@ -35,6 +35,7 @@ bool is_aotriton_backend_supported( void fused_attn_aotriton_fwd_qkvpacked( size_t b, size_t h, size_t max_seqlen, size_t d, bool is_training, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_QKV, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, @@ -46,6 +47,7 @@ void fused_attn_aotriton_fwd_qkvpacked( void fused_attn_aotriton_bwd_qkvpacked( size_t b, size_t h, size_t max_seqlen, size_t d, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -58,6 +60,7 @@ void fused_attn_aotriton_bwd_qkvpacked( void fused_attn_aotriton_fwd_kvpacked( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, bool is_training, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_KV, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, @@ -70,6 +73,7 @@ void fused_attn_aotriton_fwd_kvpacked( void fused_attn_aotriton_bwd_kvpacked( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -83,6 +87,7 @@ void fused_attn_aotriton_bwd_kvpacked( void fused_attn_aotriton_fwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, bool is_training, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, @@ -95,6 +100,7 @@ void fused_attn_aotriton_fwd( void fused_attn_aotriton_bwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, + uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, From 4455361ee0017ff8f4650e884b10f4fef05daeaf Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 3 Nov 2025 14:26:23 -0600 Subject: [PATCH 04/21] Fixed build, reverted AOTriton bwd changes (now V2) --- transformer_engine/common/CMakeLists.txt | 4 +- .../fused_attn_rocm/fused_attn_aotriton.cpp | 179 ++++-------------- 2 files changed, 38 insertions(+), 145 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 828ec51ff..2213f48b4 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -345,7 +345,6 @@ else() SET(RECURSIVE "ON") endif() message(STATUS "No-image mode: ${noimage}.") - string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}") ExternalProject_Add(aotriton_external LIST_SEPARATOR "," SOURCE_DIR ${TE}/3rdparty/aotriton @@ -368,16 +367,19 @@ else() PATTERN "cmake" EXCLUDE ) endfunction() + string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}") add_library(aotriton INTERFACE) message(STATUS "Building AOTriton from source.") aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE}) if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) + message(STATUS "Downloading AOTriton GPU Kernels.") set(__AOTRITON_CHAINED_IMAGE "aotriton_external") foreach(image ${__AOTRITON_IMAGE_LIST}) string(SUBSTRING ${image} 7 -1 gfx_pattern) string(REPLACE "x" "." gfx_regex ${gfx_pattern}) foreach(target ${ARCH_LIST_COMMA_STR}) if(target MATCHES ${gfx_regex}) + message(STATUS "Downloading AOTriton image ${image}.") set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern}) aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET}) add_dependencies(${__AOTRITON_CHAINED_IMAGE} ${__AOTRITON_DOWNLOAD_TARGET}) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 055548836..50722bf17 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -342,13 +342,12 @@ struct LazyTensorFunctions { void fused_attn_aotriton_bwd_impl( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d, float scaling_factor, float dropout_probability, - uint64_t window_size_left, uint64_t window_size_right, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool pad_between_seqs, + NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrO, void* devPtrSoftmaxAux, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, - void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, const uint64_t* devPtrDropoutSeed, const uint64_t* devPtrDropoutOffset, aotriton::DType dtype, @@ -356,20 +355,12 @@ void fused_attn_aotriton_bwd_impl( size_t *workspace_size, cudaStream_t stream) { - const uint64_t dq_acc_size = b*s_q*h*d*sizeof(float); - // Exit to request upper level API to allocate memory if needed if(workspace==nullptr){ - // AOTriton requires workspace for lse softmax + // CK only requires workspace for lse softmax *workspace_size = b*h*s_q*sizeof(float); - // AOTriton requires workspace for DQ_ACC - *workspace_size += dq_acc_size; return; } - void * softmax_lse_ptr = workspace; - workspace = static_cast(static_cast(workspace) + b*h*s_q*sizeof(float)); - void * dq_acc_ptr = workspace; - std::array q_stride; std::array k_stride; std::array v_stride; @@ -389,7 +380,7 @@ void fused_attn_aotriton_bwd_impl( std::array q_shape{b, h, s_q, d}; std::array kv_shape{b, hg, s_kv, d}; - // m and softmax_lse are of the same shape and stride + // m and workspace are of the same shape and stride std::array m_shape{b * h, s_q}; std::array m_stride{s_q, 1}; @@ -407,57 +398,13 @@ void fused_attn_aotriton_bwd_impl( // auxilary tensors auto M_tensor = aotriton::TensorView<2>(reinterpret_cast(devPtrSoftmaxAux), m_shape, m_stride, aotriton::DType::kFloat32); - auto softmax_lse_tensor = aotriton::TensorView<2>(reinterpret_cast(softmax_lse_ptr), m_shape, m_stride, aotriton::DType::kFloat32); - auto dq_acc_tensor = aotriton::TensorView<4>(reinterpret_cast(dq_acc_ptr), q_shape, q_stride, aotriton::DType::kFloat32); - NVTE_CHECK_CUDA(hipMemsetAsync(dq_acc_ptr, 0, dq_acc_size, stream)); - - LazyTensorContext<4> dq_acc_ctx {.tensor_view = dq_acc_tensor}; - using LTF = LazyTensorFunctions<4>; - auto dq_acc_lazy = aotriton::LazyTensor<4> { - .cookie = &dq_acc_ctx, - .acquire = <F::acquire, - .dispose = <F::dispose - }; - - // Cumulative seqlen tensors - std::array cu_seqlens_shape{b+1}; - std::array cu_seqlens_stride{1}; - auto cu_seqlens_q = aotriton::TensorView<1>(reinterpret_cast(devPtrCuSeqlensQ), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32); - auto cu_seqlens_k = aotriton::TensorView<1>(reinterpret_cast(devPtrCuSeqlensKV), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32); + auto wkspace_tensor = aotriton::TensorView<2>(reinterpret_cast(workspace), m_shape, m_stride, aotriton::DType::kFloat32); bool nvte_log_aotriton_config = false; if (const char* env_p = std::getenv("NVTE_LOG_AOTRITON_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") nvte_log_aotriton_config = true; } - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype); - using aotriton::v2::flash::attn_bwd; - auto seed = mk_aoscalartensor(devPtrDropoutSeed); - auto offset = mk_aoscalartensor(devPtrDropoutOffset); - const auto is_causal = mask_type == NVTE_CAUSAL_MASK; - int8_t varlen_type = 0; - auto qkv_format = nvte_get_qkv_format(layout); - if(pad_between_seqs){ - varlen_type = 2; - }else if(qkv_format == NVTE_QKV_Format::NVTE_THD){ - varlen_type = 1; - } - int window_left = s_q; - int window_right = s_kv; - bool needs_swa = false; - using aotriton::v3::flash::WindowValue; - if (is_causal) { - window_left = WindowValue::BottomRightAligned; - window_right = WindowValue::BottomRightAligned; - } - if (window_size_left>0 || window_size_right>0) { - needs_swa = true; - window_left = (window_size_left>0)? window_size_left:window_left; - window_right = (window_size_right>0)? window_size_right:window_right; - } - using aotriton::v3::flash::CausalType; - int8_t causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; - if (nvte_log_aotriton_config) { std::cout<*>(dq_acc_lazy.cookie)->tensor_view.data_ptr()<<"\n"; + std::cout<<"dropout_p: "< empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype); + using aotriton::v2::flash::attn_bwd; + auto seed = mk_aoscalartensor(devPtrDropoutSeed); + auto offset = mk_aoscalartensor(devPtrDropoutOffset); + const auto is_causal = mask_type == NVTE_CAUSAL_MASK; + NVTE_CHECK_CUDA(attn_bwd(q_tensor, + k_tensor, + v_tensor, + empty_bias, + scaling_factor, + o_tensor, + do_tensor, + dq_tensor, + dk_tensor, + dv_tensor, + empty_bias, + M_tensor, + wkspace_tensor, + dropout_probability, + seed, + offset, + 0, + is_causal, + stream)); } #endif // USE_FUSED_ATTN_AOTRITON } // namespace fused_attn_rocm @@ -661,24 +579,15 @@ void fused_attn_aotriton_bwd_qkvpacked( void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); size_t workspace_size = 0; - bool pad_between_seqs = get_pad_between_seqs( - input_cu_seqlens, - input_cu_seqlens, - nvte_get_qkv_format(qkv_layout), - attn_mask_type - ); - fused_attn_aotriton_bwd_impl( b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, - window_left, window_right, qkv_layout, - bias_type, attn_mask_type, pad_between_seqs, + bias_type, attn_mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - input_cu_seqlens->data.dptr, input_cu_seqlens->data.dptr, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, nvte_to_aotriton_dtype(QKV_type), @@ -837,24 +746,15 @@ void fused_attn_aotriton_bwd_kvpacked( void *devPtrSoftmaxStats = output_S->data.dptr; size_t workspace_size = 0; - bool pad_between_seqs = get_pad_between_seqs( - input_cu_seqlens_q, - input_cu_seqlens_kv, - nvte_get_qkv_format(qkv_layout), - attn_mask_type - ); - fused_attn_aotriton_bwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - window_left, window_right, qkv_layout, - bias_type, attn_mask_type, pad_between_seqs, + bias_type, attn_mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, nvte_to_aotriton_dtype(QKV_type), @@ -992,24 +892,15 @@ void fused_attn_aotriton_bwd( void *devPtrSoftmaxStats = output_S->data.dptr; size_t workspace_size = 0; - bool pad_between_seqs = get_pad_between_seqs( - input_cu_seqlens_q, - input_cu_seqlens_kv, - nvte_get_qkv_format(qkv_layout), - attn_mask_type - ); - fused_attn_aotriton_bwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - window_left, window_right, qkv_layout, - bias_type, attn_mask_type, pad_between_seqs, + bias_type, attn_mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, nvte_to_aotriton_dtype(QKV_type), From 2586b185f39ebca57be4b304fdd6dbee5d2bdae0 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 3 Nov 2025 14:37:15 -0600 Subject: [PATCH 05/21] Removed alterations --- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 3 --- .../common/fused_attn_rocm/fused_attn_aotriton.h | 3 --- 2 files changed, 6 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 50722bf17..e21aedfdd 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -544,7 +544,6 @@ void fused_attn_aotriton_fwd_qkvpacked( void fused_attn_aotriton_bwd_qkvpacked( size_t b, size_t h, size_t max_seqlen, size_t d, float attn_scale, float dropout, - uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -708,7 +707,6 @@ void fused_attn_aotriton_fwd_kvpacked( void fused_attn_aotriton_bwd_kvpacked( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, - uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -866,7 +864,6 @@ void fused_attn_aotriton_fwd( void fused_attn_aotriton_bwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, - uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h index 3730cdad6..0e44058dd 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h @@ -47,7 +47,6 @@ void fused_attn_aotriton_fwd_qkvpacked( void fused_attn_aotriton_bwd_qkvpacked( size_t b, size_t h, size_t max_seqlen, size_t d, float attn_scale, float dropout, - uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -73,7 +72,6 @@ void fused_attn_aotriton_fwd_kvpacked( void fused_attn_aotriton_bwd_kvpacked( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, - uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -100,7 +98,6 @@ void fused_attn_aotriton_fwd( void fused_attn_aotriton_bwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, - uint64_t window_left, uint64_t window_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, From aa80f81ef6d8489b8419ab710f6a94577b917c62 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 3 Nov 2025 16:13:50 -0600 Subject: [PATCH 06/21] Removed lazy tensor wrapper --- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index e21aedfdd..ccc515f1b 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -325,20 +325,6 @@ void fused_attn_aotriton_fwd_impl( NVTE_CHECK_CUDA(attn_fwd(fwd_params, fwd_params.kVersion, stream)); } -// A thin conversion wrapper around eager tensor-views to lazy tensors -template -struct LazyTensorContext { - aotriton::TensorView tensor_view; -}; -template -struct LazyTensorFunctions { - static aotriton::TensorView acquire(void* cookie) { - return static_cast*>(cookie)->tensor_view; - } - static void dispose(void* cookie) { - } -}; - void fused_attn_aotriton_bwd_impl( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d, float scaling_factor, float dropout_probability, From 9a91b9ee91e53f67558dae886dd3cfd69fb28511 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 4 Nov 2025 12:23:45 -0600 Subject: [PATCH 07/21] Streamlined cmakelist, other PR review feedback adressed --- transformer_engine/common/CMakeLists.txt | 25 ++++++----- .../common/fused_attn_rocm/fused_attn.cpp | 3 -- .../fused_attn_rocm/fused_attn_aotriton.cpp | 43 ++++++++++--------- .../fused_attn_rocm/fused_attn_aotriton.h | 6 +-- 4 files changed, 37 insertions(+), 40 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 2213f48b4..d15332530 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -278,11 +278,15 @@ else() set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES}) if(USE_FUSED_ATTN_AOTRITON) + # The AOTriton C++ runtime will be built from ../../3rdparty/aotriton + # Hence there is no need to add multiple ROCM version here set(__AOTRITON_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton") set(__AOTRITON_SUFFIX "_TEprivate") if(NOT DEFINED AOTRITON_PATH) + # If AOTRITON_PATH is not provided, we proceed to build the runtime + # ourselves and either build or download the GPU kernels if(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) set(AOTRITON_NOIMAGE_MODE OFF) else() @@ -290,9 +294,6 @@ else() endif() set(__AOTRITON_VER "0.11b") - set(__AOTRITON_SHA256 - "a2a974e0ad929a5e5827c0f896c59bda4872459cbaf8dd8e0a00407f404491cf" # rocm7.0 - ) set(__AOTRITON_IMAGE_LIST "amd-gfx942" "amd-gfx950" @@ -303,10 +304,9 @@ else() ) set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore set(__AOTRITON_Z "gz") - # Set the default __AOTRITON_LIB path - set(__AOTRITON_LIB "lib/libaotriton_v2.so") include(ExternalProject) + # Download GPU kernels for a specific target function(aotriton_download_image image project) list(FIND __AOTRITON_IMAGE_LIST ${image} index) list(GET __AOTRITON_IMAGE_SHA256_LIST ${index} __AOTRITON_IMAGE_SHA256) @@ -338,12 +338,9 @@ else() message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.") endfunction() + # Build the AOTriton runtime from source with custom suffix to avoid + # potential conflict with libaotriton as provided by PyTorch function(aotriton_build_from_source noimage) - if(noimage) - SET(RECURSIVE "OFF") - else() - SET(RECURSIVE "ON") - endif() message(STATUS "No-image mode: ${noimage}.") ExternalProject_Add(aotriton_external LIST_SEPARATOR "," @@ -354,7 +351,7 @@ else() -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX} - -DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE} + -DAOTRITON_NOIMAGE_MODE=${noimage} BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so" ) message(STATUS "Adding AOTriton library.") @@ -367,10 +364,13 @@ else() PATTERN "cmake" EXCLUDE ) endfunction() - string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}") + add_library(aotriton INTERFACE) message(STATUS "Building AOTriton from source.") + string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}") aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE}) + + # Download GPU kernels if needed if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) message(STATUS "Downloading AOTriton GPU Kernels.") set(__AOTRITON_CHAINED_IMAGE "aotriton_external") @@ -391,7 +391,6 @@ else() else() endif() - else() # Use aotriton built during initial TE building/installation # When only need rebuild TE library itself diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 01b5989da..392a73bfa 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -492,7 +492,6 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con fused_attn_aotriton_bwd_qkvpacked( b, h, max_seqlen, d, attn_scale, dropout, - window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_QKV, input_O, input_dO, output_S, output_dQKV, @@ -679,7 +678,6 @@ void nvte_fused_attn_bwd_kvpacked( fused_attn_aotriton_bwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, input_dO, output_S, @@ -860,7 +858,6 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_aotriton_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, output_S, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index ccc515f1b..3eb623716 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -70,6 +70,20 @@ bool get_pad_between_seqs( return pad_between_seqs; } +std::tuple get_window_sizes(int window_size_left, int window_size_right, bool is_causal){ + int window_left = 0; + int window_right = 0; + using aotriton::v3::flash::WindowValue; + if (is_causal) { + window_left = WindowValue::BottomRightAligned; + window_right = WindowValue::BottomRightAligned; + } else if (window_size_left>0 || window_size_right>0) { + window_left = (window_size_left>0)? window_size_left:window_left; + window_right = (window_size_right>0)? window_size_right:window_right; + } + return {window_left, window_right}; +} + // check the fused attn config to see whether it's aotriton backend supported bool is_aotriton_backend_supported( NVTEDType q_dtype, @@ -231,27 +245,14 @@ void fused_attn_aotriton_fwd_impl( auto offset_output = mk_aoscalartensor(nullptr); const auto is_causal = mask_type == NVTE_CAUSAL_MASK; aotriton::TensorView<0> atomic_for_causal(reinterpret_cast(workspace), aotriton::DType::kInt32); - int8_t varlen_type = 0; - auto qkv_format = nvte_get_qkv_format(layout); - if(pad_between_seqs){ - varlen_type = 2; - }else if(qkv_format == NVTE_QKV_Format::NVTE_THD){ - varlen_type = 1; - } - int window_left = 0; - int window_right = 0; - using aotriton::v3::flash::WindowValue; - if (is_causal) { - window_left = WindowValue::BottomRightAligned; - window_right = WindowValue::BottomRightAligned; - } - if (window_size_left>0 || window_size_right>0) { - window_left = (window_size_left>0)? window_size_left:window_left; - window_right = (window_size_right>0)? window_size_right:window_right; - } + using aotriton::v3::flash::VarlenType; + int8_t varlen_type = VarlenType::None; + + auto [window_left, window_right] = get_window_sizes(window_size_left, window_size_right, is_causal); using aotriton::v3::flash::CausalType; int8_t causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; + if (nvte_log_aotriton_config) { std::cout< Date: Tue, 4 Nov 2025 12:27:17 -0600 Subject: [PATCH 08/21] Removed `pad_between_seqs` --- .../fused_attn_rocm/fused_attn_aotriton.cpp | 61 ++----------------- 1 file changed, 5 insertions(+), 56 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 3eb623716..121834dc2 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -37,40 +37,7 @@ inline aotriton::TensorView<0> mk_aoscalartensor(const uint64_t* ptr) namespace transformer_engine { namespace fused_attn_rocm { -bool get_pad_between_seqs( - const Tensor* input_cu_seqlens, - const Tensor* input_cu_seqlens_padded, - NVTE_QKV_Format qkv_format, NVTE_Mask_Type attn_mask_type -){ - bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD; - bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - // First we check whether we have a ragged array with a non-trivial - // input_cu_seqlens_padded tensor - bool pad_between_seqs = ( - is_ragged - && input_cu_seqlens->data.dptr!=input_cu_seqlens_padded->data.dptr - && !input_cu_seqlens_padded->data.shape.empty() - ); - // Next we guard against an initial workspace-allocation which occurs in the - // JAX TE extension. We check for both pointers being null while retaining - // shape data, indicating the use of dummy data in the allocation pass. - pad_between_seqs = pad_between_seqs || ( - is_ragged - && input_cu_seqlens->data.dptr==nullptr && !input_cu_seqlens->data.shape.empty() - && input_cu_seqlens_padded->data.dptr==nullptr && !input_cu_seqlens_padded->data.shape.empty() - ); - // Finally we check whether we have an array with padding and non-empty input_cu_seqlens - pad_between_seqs = pad_between_seqs || ( - !is_ragged - && is_padding - && !input_cu_seqlens->data.shape.empty() - ); - return pad_between_seqs; -} - -std::tuple get_window_sizes(int window_size_left, int window_size_right, bool is_causal){ + std::tuple get_window_sizes(int window_size_left, int window_size_right, bool is_causal){ int window_left = 0; int window_right = 0; using aotriton::v3::flash::WindowValue; @@ -175,7 +142,7 @@ void fused_attn_aotriton_fwd_impl( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d, bool is_training, float scaling_factor, float dropout_probability, int window_size_left, int window_size_right, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool pad_between_seqs, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrSoftmaxAux, void *devPtrO, const uint64_t* devPtrDropoutSeed, const uint64_t* devPtrDropoutOffset, @@ -487,19 +454,13 @@ void fused_attn_aotriton_fwd_qkvpacked( } size_t workspace_size = 0; - bool pad_between_seqs = get_pad_between_seqs( - input_cu_seqlens, - input_cu_seqlens, - nvte_get_qkv_format(qkv_layout), - attn_mask_type - ); fused_attn_aotriton_fwd_impl( b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, window_left, window_right, qkv_layout, - bias_type, attn_mask_type, pad_between_seqs, + bias_type, attn_mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, reinterpret_cast(rng_state->data.dptr), @@ -650,19 +611,13 @@ void fused_attn_aotriton_fwd_kvpacked( } size_t workspace_size = 0; - bool pad_between_seqs = get_pad_between_seqs( - input_cu_seqlens_q, - input_cu_seqlens_kv, - nvte_get_qkv_format(qkv_layout), - attn_mask_type - ); fused_attn_aotriton_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, window_left, window_right, qkv_layout, - bias_type, attn_mask_type, pad_between_seqs, + bias_type, attn_mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, reinterpret_cast(rng_state->data.dptr), @@ -807,19 +762,13 @@ void fused_attn_aotriton_fwd( } size_t workspace_size = 0; - bool pad_between_seqs = get_pad_between_seqs( - input_cu_seqlens_q, - input_cu_seqlens_kv, - nvte_get_qkv_format(qkv_layout), - attn_mask_type - ); fused_attn_aotriton_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, window_left, window_right, qkv_layout, - bias_type, attn_mask_type, pad_between_seqs, + bias_type, attn_mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, reinterpret_cast(rng_state->data.dptr), From 6b8dbe556e9a9c40e952994ebe7778428590dff3 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 4 Nov 2025 12:29:33 -0600 Subject: [PATCH 09/21] Updated typing to be more explicit --- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 121834dc2..31c44b203 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -37,9 +37,9 @@ inline aotriton::TensorView<0> mk_aoscalartensor(const uint64_t* ptr) namespace transformer_engine { namespace fused_attn_rocm { - std::tuple get_window_sizes(int window_size_left, int window_size_right, bool is_causal){ - int window_left = 0; - int window_right = 0; + std::tuple get_window_sizes(int32_t window_size_left, int32_t window_size_right, bool is_causal){ + int32_t window_left = 0; + int32_t window_right = 0; using aotriton::v3::flash::WindowValue; if (is_causal) { window_left = WindowValue::BottomRightAligned; @@ -141,7 +141,7 @@ aotriton::DType nvte_to_aotriton_dtype(DType t_dtype){ void fused_attn_aotriton_fwd_impl( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d, bool is_training, float scaling_factor, float dropout_probability, - int window_size_left, int window_size_right, NVTE_QKV_Layout layout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrSoftmaxAux, void *devPtrO, From 68303d063854c529cb90675cb680e2ee6737ce8b Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 4 Nov 2025 14:58:56 -0600 Subject: [PATCH 10/21] Minor streamlining and formatting --- transformer_engine/common/CMakeLists.txt | 11 ++++------- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 10 +++++++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index d15332530..2f2e4fb42 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -303,7 +303,6 @@ else() "27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950 ) set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore - set(__AOTRITON_Z "gz") include(ExternalProject) # Download GPU kernels for a specific target @@ -313,14 +312,14 @@ else() string(CONCAT __AOTRITON_FILE "aotriton-${__AOTRITON_VER}-images-" - "${image}.tar.${__AOTRITON_Z}") + "${image}.tar.gz") string(CONCAT __AOTRITON_URL "${__AOTRITON_BASE_URL}" "${__AOTRITON_VER}/${__AOTRITON_FILE}") # Set up directories - set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image}) - set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}) + set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton/download/${image}) + set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton/image/${image}) ExternalProject_Add(${project} URL "${__AOTRITON_URL}" @@ -373,7 +372,6 @@ else() # Download GPU kernels if needed if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) message(STATUS "Downloading AOTriton GPU Kernels.") - set(__AOTRITON_CHAINED_IMAGE "aotriton_external") foreach(image ${__AOTRITON_IMAGE_LIST}) string(SUBSTRING ${image} 7 -1 gfx_pattern) string(REPLACE "x" "." gfx_regex ${gfx_pattern}) @@ -382,8 +380,7 @@ else() message(STATUS "Downloading AOTriton image ${image}.") set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern}) aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET}) - add_dependencies(${__AOTRITON_CHAINED_IMAGE} ${__AOTRITON_DOWNLOAD_TARGET}) - set(__AOTRITON_CHAINED_IMAGE ${__AOTRITON_DOWNLOAD_TARGET}) + add_dependencies(aotriton_external ${__AOTRITON_DOWNLOAD_TARGET}) break() endif() endforeach() diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 31c44b203..6202c3518 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -37,14 +37,18 @@ inline aotriton::TensorView<0> mk_aoscalartensor(const uint64_t* ptr) namespace transformer_engine { namespace fused_attn_rocm { - std::tuple get_window_sizes(int32_t window_size_left, int32_t window_size_right, bool is_causal){ +std::tuple get_window_sizes( + int32_t window_size_left, + int32_t window_size_right, + bool is_causal +){ int32_t window_left = 0; int32_t window_right = 0; using aotriton::v3::flash::WindowValue; - if (is_causal) { + if(is_causal){ window_left = WindowValue::BottomRightAligned; window_right = WindowValue::BottomRightAligned; - } else if (window_size_left>0 || window_size_right>0) { + }else{ window_left = (window_size_left>0)? window_size_left:window_left; window_right = (window_size_right>0)? window_size_right:window_right; } From 8181972e887657198d201fe8bccf1027a9aced20 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 6 Nov 2025 12:36:02 -0600 Subject: [PATCH 11/21] Initial implementation --- .../common/fused_attn_rocm/fused_attn.cpp | 3 + .../fused_attn_rocm/fused_attn_aotriton.cpp | 157 ++++++++++++++---- .../fused_attn_rocm/fused_attn_aotriton.h | 3 + 3 files changed, 132 insertions(+), 31 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 392a73bfa..01b5989da 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -492,6 +492,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con fused_attn_aotriton_bwd_qkvpacked( b, h, max_seqlen, d, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_QKV, input_O, input_dO, output_S, output_dQKV, @@ -678,6 +679,7 @@ void nvte_fused_attn_bwd_kvpacked( fused_attn_aotriton_bwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, input_dO, output_S, @@ -858,6 +860,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_aotriton_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, output_S, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 6202c3518..e34da2bed 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -297,15 +297,30 @@ void fused_attn_aotriton_fwd_impl( NVTE_CHECK_CUDA(attn_fwd(fwd_params, fwd_params.kVersion, stream)); } +// A thin conversion wrapper around eager tensor-views to lazy tensors +template +struct LazyTensorContext { + aotriton::TensorView tensor_view; +}; +template +struct LazyTensorFunctions { + static aotriton::TensorView acquire(void* cookie) { + return static_cast*>(cookie)->tensor_view; + } + static void dispose(void* cookie) { + } +}; + void fused_attn_aotriton_bwd_impl( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrO, void* devPtrSoftmaxAux, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, + void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, const uint64_t* devPtrDropoutSeed, const uint64_t* devPtrDropoutOffset, aotriton::DType dtype, @@ -313,12 +328,20 @@ void fused_attn_aotriton_bwd_impl( size_t *workspace_size, cudaStream_t stream) { + const uint64_t dq_acc_size = b*s_q*h*d*sizeof(float); + // Exit to request upper level API to allocate memory if needed if(workspace==nullptr){ - // CK only requires workspace for lse softmax + // AOTriton requires workspace for lse softmax *workspace_size = b*h*s_q*sizeof(float); + // AOTriton requires workspace for DQ_ACC + *workspace_size += dq_acc_size; return; } + void * delta = workspace; + workspace = static_cast(static_cast(workspace) + b*h*s_q*sizeof(float)); + void * dq_acc_ptr = workspace; + std::array q_stride; std::array k_stride; std::array v_stride; @@ -338,7 +361,7 @@ void fused_attn_aotriton_bwd_impl( std::array q_shape{b, h, s_q, d}; std::array kv_shape{b, hg, s_kv, d}; - // m and workspace are of the same shape and stride + // m and softmax_lse are of the same shape and stride std::array m_shape{b * h, s_q}; std::array m_stride{s_q, 1}; @@ -356,13 +379,47 @@ void fused_attn_aotriton_bwd_impl( // auxilary tensors auto M_tensor = aotriton::TensorView<2>(reinterpret_cast(devPtrSoftmaxAux), m_shape, m_stride, aotriton::DType::kFloat32); - auto wkspace_tensor = aotriton::TensorView<2>(reinterpret_cast(workspace), m_shape, m_stride, aotriton::DType::kFloat32); + auto delta_tensor = aotriton::TensorView<2>(reinterpret_cast(delta), m_shape, m_stride, aotriton::DType::kFloat32); + auto dq_acc_tensor = aotriton::TensorView<4>(reinterpret_cast(dq_acc_ptr), q_shape, q_stride, aotriton::DType::kFloat32); + NVTE_CHECK_CUDA(hipMemsetAsync(dq_acc_ptr, 0, dq_acc_size, stream)); + + LazyTensorContext<4> dq_acc_ctx {.tensor_view = dq_acc_tensor}; + LazyTensorContext<2> delta_ctx {.tensor_view = delta_tensor}; + auto dq_acc_lazy = aotriton::LazyTensor<4> { + .cookie = &dq_acc_ctx, + .acquire = &LazyTensorFunctions<4>::acquire, + .dispose = &LazyTensorFunctions<4>::dispose + }; + auto delta_lazy = aotriton::LazyTensor<2> { + .cookie = &delta_ctx, + .acquire = &LazyTensorFunctions<2>::acquire, + .dispose = &LazyTensorFunctions<2>::dispose + }; + + // Cumulative seqlen tensors + std::array cu_seqlens_shape{b+1}; + std::array cu_seqlens_stride{1}; + auto cu_seqlens_q = aotriton::TensorView<1>(reinterpret_cast(devPtrCuSeqlensQ), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32); + auto cu_seqlens_k = aotriton::TensorView<1>(reinterpret_cast(devPtrCuSeqlensKV), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32); bool nvte_log_aotriton_config = false; if (const char* env_p = std::getenv("NVTE_LOG_AOTRITON_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") nvte_log_aotriton_config = true; } + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype); + using aotriton::v2::flash::attn_bwd; + auto seed = mk_aoscalartensor(devPtrDropoutSeed); + auto offset = mk_aoscalartensor(devPtrDropoutOffset); + const auto is_causal = mask_type == NVTE_CAUSAL_MASK; + + using aotriton::v3::flash::VarlenType; + int8_t varlen_type = VarlenType::None; + + auto [window_left, window_right] = get_window_sizes(window_size_left, window_size_right, is_causal); + using aotriton::v3::flash::CausalType; + int8_t causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; + if (nvte_log_aotriton_config) { std::cout<*>(delta_lazy.cookie)->tensor_view.data_ptr()<<"\n"; + std::cout<<"dropout_p: "<*>(dq_acc_lazy.cookie)->tensor_view.data_ptr()<<"\n"; } - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype); - using aotriton::v2::flash::attn_bwd; - auto seed = mk_aoscalartensor(devPtrDropoutSeed); - auto offset = mk_aoscalartensor(devPtrDropoutOffset); - const auto is_causal = mask_type == NVTE_CAUSAL_MASK; - NVTE_CHECK_CUDA(attn_bwd(q_tensor, - k_tensor, - v_tensor, - empty_bias, - scaling_factor, - o_tensor, - do_tensor, - dq_tensor, - dk_tensor, - dv_tensor, - empty_bias, - M_tensor, - wkspace_tensor, - dropout_probability, - seed, - offset, - 0, - is_causal, - stream)); + aotriton::v3::flash::attn_bwd_params bwd_params{}; + bwd_params.Q = q_tensor; + bwd_params.K = k_tensor; + bwd_params.V = v_tensor; + bwd_params.B = empty_bias; + bwd_params.Sm_scale = scaling_factor; + bwd_params.Out = o_tensor; + if(varlen_type){ + bwd_params.cu_seqlens_q = cu_seqlens_q; + bwd_params.cu_seqlens_k = cu_seqlens_k; + bwd_params.Max_seqlen_q = s_q; + bwd_params.Max_seqlen_k = s_kv; + } + bwd_params.DO = o_tensor; + bwd_params.DK = do_tensor; + bwd_params.DV = dq_tensor; + bwd_params.DQ = dv_tensor; + bwd_params.DB = empty_bias; + bwd_params.L = M_tensor; + bwd_params.D = delta_lazy; + bwd_params.dropout_p = dropout_probability; + bwd_params.philox_seed_ptr = seed; + bwd_params.philox_offset1 = offset; + bwd_params.philox_offset2 = 0; + bwd_params.causal_type = is_causal; + bwd_params.varlen_type = varlen_type; + bwd_params.window_left = window_left; + bwd_params.window_right = window_right; + bwd_params.DQ_ACC = dq_acc_lazy; + + NVTE_CHECK_CUDA(attn_bwd(bwd_params, bwd_params.kVersion, stream)); } #endif // USE_FUSED_ATTN_AOTRITON } // namespace fused_attn_rocm @@ -496,6 +585,7 @@ void fused_attn_aotriton_fwd_qkvpacked( void fused_attn_aotriton_bwd_qkvpacked( size_t b, size_t h, size_t max_seqlen, size_t d, float attn_scale, float dropout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -533,6 +623,7 @@ void fused_attn_aotriton_bwd_qkvpacked( fused_attn_aotriton_bwd_impl( b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, devPtrQ, devPtrK, devPtrV, @@ -653,6 +744,7 @@ void fused_attn_aotriton_fwd_kvpacked( void fused_attn_aotriton_bwd_kvpacked( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -693,6 +785,7 @@ void fused_attn_aotriton_bwd_kvpacked( fused_attn_aotriton_bwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, devPtrQ, devPtrK, devPtrV, @@ -804,6 +897,7 @@ void fused_attn_aotriton_fwd( void fused_attn_aotriton_bwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -832,6 +926,7 @@ void fused_attn_aotriton_bwd( fused_attn_aotriton_bwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + window_size_left, window_size_right, qkv_layout, bias_type, attn_mask_type, devPtrQ, devPtrK, devPtrV, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h index b016acc67..7818d2b33 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h @@ -47,6 +47,7 @@ void fused_attn_aotriton_fwd_qkvpacked( void fused_attn_aotriton_bwd_qkvpacked( size_t b, size_t h, size_t max_seqlen, size_t d, float attn_scale, float dropout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -72,6 +73,7 @@ void fused_attn_aotriton_fwd_kvpacked( void fused_attn_aotriton_bwd_kvpacked( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -98,6 +100,7 @@ void fused_attn_aotriton_fwd( void fused_attn_aotriton_bwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, From 6788a162454b3995709233aa851e1f704addfe85 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 6 Nov 2025 13:22:46 -0600 Subject: [PATCH 12/21] Simplified window size func for current non-SWA support --- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 6202c3518..fa923b243 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -17,6 +17,7 @@ #include "../util/system.h" #include "fused_attn_aotriton.h" #include "utils.h" +#include #ifdef USE_FUSED_ATTN_AOTRITON #if AOTRITON_ENABLE_SUFFIX @@ -37,22 +38,17 @@ inline aotriton::TensorView<0> mk_aoscalartensor(const uint64_t* ptr) namespace transformer_engine { namespace fused_attn_rocm { +// TODO: Support SWA std::tuple get_window_sizes( int32_t window_size_left, int32_t window_size_right, bool is_causal ){ - int32_t window_left = 0; - int32_t window_right = 0; using aotriton::v3::flash::WindowValue; if(is_causal){ - window_left = WindowValue::BottomRightAligned; - window_right = WindowValue::BottomRightAligned; - }else{ - window_left = (window_size_left>0)? window_size_left:window_left; - window_right = (window_size_right>0)? window_size_right:window_right; + return {WindowValue::BottomRightAligned, WindowValue::BottomRightAligned}; } - return {window_left, window_right}; + return {-1, -1}; } // check the fused attn config to see whether it's aotriton backend supported From 182101ab9ae1a51b5a52a94404fd595fa589ec3a Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 6 Nov 2025 13:28:12 -0600 Subject: [PATCH 13/21] Removed accidental include --- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index fa923b243..72d0d28bf 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -17,7 +17,6 @@ #include "../util/system.h" #include "fused_attn_aotriton.h" #include "utils.h" -#include #ifdef USE_FUSED_ATTN_AOTRITON #if AOTRITON_ENABLE_SUFFIX From fef6baa113317a561f3a1b3e902d66cb09851868 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 6 Nov 2025 16:18:00 -0600 Subject: [PATCH 14/21] Corrected bwd args --- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 13 ++++++++----- .../common/fused_attn_rocm/fused_attn_aotriton.h | 6 +++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 9ac522cf6..4526e9ae9 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -468,13 +468,13 @@ void fused_attn_aotriton_bwd_impl( bwd_params.Max_seqlen_q = s_q; bwd_params.Max_seqlen_k = s_kv; } - bwd_params.DO = o_tensor; - bwd_params.DK = do_tensor; - bwd_params.DV = dq_tensor; - bwd_params.DQ = dv_tensor; + bwd_params.DO = do_tensor; + bwd_params.DK = dk_tensor; + bwd_params.DV = dv_tensor; + bwd_params.DQ = dq_tensor; bwd_params.DB = empty_bias; bwd_params.L = M_tensor; - bwd_params.D = delta_lazy; + bwd_params.D = delta_lazy; bwd_params.dropout_p = dropout_probability; bwd_params.philox_seed_ptr = seed; bwd_params.philox_offset1 = offset; @@ -625,6 +625,7 @@ void fused_attn_aotriton_bwd_qkvpacked( devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + input_cu_seqlens->data.dptr, input_cu_seqlens->data.dptr, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, nvte_to_aotriton_dtype(QKV_type), @@ -787,6 +788,7 @@ void fused_attn_aotriton_bwd_kvpacked( devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, nvte_to_aotriton_dtype(QKV_type), @@ -928,6 +930,7 @@ void fused_attn_aotriton_bwd( devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, reinterpret_cast(rng_state->data.dptr), reinterpret_cast(rng_state->data.dptr) + 1, nvte_to_aotriton_dtype(QKV_type), diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h index 7818d2b33..3fdb359d1 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h @@ -47,7 +47,7 @@ void fused_attn_aotriton_fwd_qkvpacked( void fused_attn_aotriton_bwd_qkvpacked( size_t b, size_t h, size_t max_seqlen, size_t d, float attn_scale, float dropout, - int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -73,7 +73,7 @@ void fused_attn_aotriton_fwd_kvpacked( void fused_attn_aotriton_bwd_kvpacked( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, - int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, @@ -100,7 +100,7 @@ void fused_attn_aotriton_fwd( void fused_attn_aotriton_bwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, float attn_scale, float dropout, - int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout, + int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* output_S, From 3a4fab8cff98475e8a314e609cf1037147954b61 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 10 Nov 2025 12:51:13 -0600 Subject: [PATCH 15/21] Updated causal window default --- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 4526e9ae9..c3625af89 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -45,7 +45,7 @@ std::tuple get_window_sizes( ){ using aotriton::v3::flash::WindowValue; if(is_causal){ - return {WindowValue::BottomRightAligned, WindowValue::BottomRightAligned}; + return {WindowValue::TopLeftAligned, WindowValue::TopLeftAligned}; } return {-1, -1}; } From 917e3c305f02eb97f7b6b232c0cb06b26cca8c5e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 10 Nov 2025 12:53:55 -0600 Subject: [PATCH 16/21] Updated window values for causal --- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 72d0d28bf..3dbc68902 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -45,7 +45,7 @@ std::tuple get_window_sizes( ){ using aotriton::v3::flash::WindowValue; if(is_causal){ - return {WindowValue::BottomRightAligned, WindowValue::BottomRightAligned}; + return {WindowValue::TopLeftAligned, WindowValue::TopLeftAligned}; } return {-1, -1}; } From 36045c858b6a1d67d0a8bca6ac6d25f8c4ffc566 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 12 Nov 2025 14:57:56 -0600 Subject: [PATCH 17/21] Corrected DQ_ACC buffer, added env var for GPU kernel building --- setup.py | 2 ++ .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 91817d56e..b7c8be413 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,8 @@ def setup_common_extension() -> CMakeExtension: cmake_flags.append(f"-DAITER_MHA_PATH={ck_path}") if int(os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0: cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF") + elif int(os.getenv("NVTE_AOTRITON_BUILD_GPU_KERNELS", "0")): + cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS=ON") if int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0: cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF") else: diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index c3625af89..184b0b290 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -341,6 +341,7 @@ void fused_attn_aotriton_bwd_impl( std::array k_stride; std::array v_stride; std::array o_stride; + std::array dq_acc_stride; generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), @@ -349,6 +350,9 @@ void fused_attn_aotriton_bwd_impl( layout, NVTE_QKV_Matrix::NVTE_V_Matrix); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + // AOTriton expects a BSHD layout DQ_ACC matrix + generateMatrixStrides(b, h, s_q, s_kv, d, dq_acc_stride.data(), + NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Matrix::NVTE_Q_Matrix); //q and o are having the same shape //k and v are having the same shape @@ -375,7 +379,7 @@ void fused_attn_aotriton_bwd_impl( // auxilary tensors auto M_tensor = aotriton::TensorView<2>(reinterpret_cast(devPtrSoftmaxAux), m_shape, m_stride, aotriton::DType::kFloat32); auto delta_tensor = aotriton::TensorView<2>(reinterpret_cast(delta), m_shape, m_stride, aotriton::DType::kFloat32); - auto dq_acc_tensor = aotriton::TensorView<4>(reinterpret_cast(dq_acc_ptr), q_shape, q_stride, aotriton::DType::kFloat32); + auto dq_acc_tensor = aotriton::TensorView<4>(reinterpret_cast(dq_acc_ptr), q_shape, dq_acc_stride, aotriton::DType::kFloat32); NVTE_CHECK_CUDA(hipMemsetAsync(dq_acc_ptr, 0, dq_acc_size, stream)); LazyTensorContext<4> dq_acc_ctx {.tensor_view = dq_acc_tensor}; @@ -479,7 +483,7 @@ void fused_attn_aotriton_bwd_impl( bwd_params.philox_seed_ptr = seed; bwd_params.philox_offset1 = offset; bwd_params.philox_offset2 = 0; - bwd_params.causal_type = is_causal; + bwd_params.causal_type = causal_type; bwd_params.varlen_type = varlen_type; bwd_params.window_left = window_left; bwd_params.window_right = window_right; From d6e46c1661f77fc2ea0d57aea86035bb339afc3d Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 12 Nov 2025 16:57:46 -0600 Subject: [PATCH 18/21] Update AOTriton to 0.11.1b --- 3rdparty/aotriton | 2 +- transformer_engine/common/CMakeLists.txt | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/3rdparty/aotriton b/3rdparty/aotriton index 972223c50..98371989e 160000 --- a/3rdparty/aotriton +++ b/3rdparty/aotriton @@ -1 +1 @@ -Subproject commit 972223c501ffc22068bb035ac5d64cf54318d895 +Subproject commit 98371989e8a23267e284c94e95156a139e4b33c4 diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 2f2e4fb42..d8ebb54cf 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -293,15 +293,15 @@ else() set(AOTRITON_NOIMAGE_MODE ON) endif() - set(__AOTRITON_VER "0.11b") + set(__AOTRITON_VER "0.11.1b") set(__AOTRITON_IMAGE_LIST "amd-gfx942" "amd-gfx950" ) set(__AOTRITON_IMAGE_SHA256_LIST - "3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942 - "27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950 - ) + "0a7bcee19d3bb6d548732248c3234f7b92736c2ab7a7aae65294b87a7fd64c06" # amd-gfx942 + "c1ba3bfe84217fd67df3dd1f8b67c80a7f7b33d0ad4d74b41d6567036e032ace" # amd-gfx950 + ) set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore include(ExternalProject) From 2bd900614ddc054526b92aaf3c9f88a519bb401b Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 25 Nov 2025 10:40:37 -0600 Subject: [PATCH 19/21] Added AOTriton commit SHA --- transformer_engine/common/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index ba1a558df..91387af16 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -295,6 +295,7 @@ else() set(AOTRITON_NOIMAGE_MODE ON) endif() + set(AOTRITON_CI_SUPPLIED_SHA1 "98371989e8a23267e284c94e95156a139e4b33c4") set(__AOTRITON_VER "0.11.1b") set(__AOTRITON_IMAGE_LIST "amd-gfx942" From 0fdff864e2f7304ff510d9c935ef97770d3307f8 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 26 Nov 2025 12:52:28 -0600 Subject: [PATCH 20/21] Moved handling of env variable to makefile --- setup.py | 2 -- transformer_engine/common/CMakeLists.txt | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 86f53b523..fd0b4a931 100644 --- a/setup.py +++ b/setup.py @@ -73,8 +73,6 @@ def setup_common_extension() -> CMakeExtension: cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF") else: os.environ["AOTRITON_CI_SUPPLIED_SHA1"] = "98371989e8a23267e284c94e95156a139e4b33c4" - if int(os.getenv("NVTE_AOTRITON_BUILD_GPU_KERNELS", "0")): - cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS=ON") if int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0: cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF") if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None: diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index ba1a558df..084eaaa2b 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -8,7 +8,6 @@ cmake_minimum_required(VERSION 3.21) option(USE_ROCM "Use ROCm" ON) option(USE_FUSED_ATTN_AOTRITON "Use aotriton backend" ON) -option(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS "Build AOTriton GPU kernels" OFF) option(USE_FUSED_ATTN_CK "Use ck backend" ON) set(USE_CUDA OFF) @@ -289,7 +288,8 @@ else() if(NOT DEFINED AOTRITON_PATH) # If AOTRITON_PATH is not provided, we proceed to build the runtime # ourselves and either build or download the GPU kernels - if(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) + set(AOTRITON_BUILD_GPU_KERNELS $ENV{NVTE_AOTRITON_BUILD_GPU_KERNELS}) + if(AOTRITON_BUILD_GPU_KERNELS) set(AOTRITON_NOIMAGE_MODE OFF) else() set(AOTRITON_NOIMAGE_MODE ON) @@ -372,7 +372,7 @@ else() aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE}) # Download GPU kernels if needed - if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS) + if(NOT AOTRITON_BUILD_GPU_KERNELS) message(STATUS "Downloading AOTriton GPU Kernels.") foreach(image ${__AOTRITON_IMAGE_LIST}) string(SUBSTRING ${image} 7 -1 gfx_pattern) From 3f6e0545d1a315435634bb6432c0161c52da31bc Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 1 Dec 2025 17:17:01 +0000 Subject: [PATCH 21/21] Simplified lazy tensor implementation --- .../fused_attn_rocm/fused_attn_aotriton.cpp | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 184b0b290..bb3582afb 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -293,14 +293,10 @@ void fused_attn_aotriton_fwd_impl( } // A thin conversion wrapper around eager tensor-views to lazy tensors -template -struct LazyTensorContext { - aotriton::TensorView tensor_view; -}; template struct LazyTensorFunctions { static aotriton::TensorView acquire(void* cookie) { - return static_cast*>(cookie)->tensor_view; + return *static_cast*>(cookie); } static void dispose(void* cookie) { } @@ -382,15 +378,13 @@ void fused_attn_aotriton_bwd_impl( auto dq_acc_tensor = aotriton::TensorView<4>(reinterpret_cast(dq_acc_ptr), q_shape, dq_acc_stride, aotriton::DType::kFloat32); NVTE_CHECK_CUDA(hipMemsetAsync(dq_acc_ptr, 0, dq_acc_size, stream)); - LazyTensorContext<4> dq_acc_ctx {.tensor_view = dq_acc_tensor}; - LazyTensorContext<2> delta_ctx {.tensor_view = delta_tensor}; auto dq_acc_lazy = aotriton::LazyTensor<4> { - .cookie = &dq_acc_ctx, + .cookie = &dq_acc_tensor, .acquire = &LazyTensorFunctions<4>::acquire, .dispose = &LazyTensorFunctions<4>::dispose }; auto delta_lazy = aotriton::LazyTensor<2> { - .cookie = &delta_ctx, + .cookie = &delta_tensor, .acquire = &LazyTensorFunctions<2>::acquire, .dispose = &LazyTensorFunctions<2>::dispose }; @@ -448,7 +442,7 @@ void fused_attn_aotriton_bwd_impl( std::cout<<"DQ: "<*>(delta_lazy.cookie)->tensor_view.data_ptr()<<"\n"; + std::cout<<"D: "<*>(dq_acc_lazy.cookie)->tensor_view.data_ptr()<<"\n"; + std::cout<<"DQ_ACC: "<