diff --git a/3rdparty/aotriton b/3rdparty/aotriton
index 6fca155f4..98371989e 160000
--- a/3rdparty/aotriton
+++ b/3rdparty/aotriton
@@ -1 +1 @@
-Subproject commit 6fca155f4deeb8d9529326f7b69f350aeeb93477
+Subproject commit 98371989e8a23267e284c94e95156a139e4b33c4
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index f70c9f8bb..d8ebb54cf 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -278,67 +278,116 @@ else()
   set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES})
   if(USE_FUSED_ATTN_AOTRITON)
-    # This is for GPU kernel downloading
     # The AOTriton C++ runtime will be built from ../../3rdparty/aotriton
     # Hence there is no need to add multiple ROCM version here
+    set(__AOTRITON_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton")
     set(__AOTRITON_SUFFIX "_TEprivate")
+
     if(NOT DEFINED AOTRITON_PATH)
-      # Install aotriton fused attn
+      # If AOTRITON_PATH is not provided, we proceed to build the runtime
+      # ourselves and either build or download the GPU kernels
       if(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
        set(AOTRITON_NOIMAGE_MODE OFF)
      else()
        set(AOTRITON_NOIMAGE_MODE ON)
      endif()
-      string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
+      set(__AOTRITON_VER "0.11.1b")
+      set(__AOTRITON_IMAGE_LIST
+          "amd-gfx942"
+          "amd-gfx950"
+      )
+      set(__AOTRITON_IMAGE_SHA256_LIST
+          "0a7bcee19d3bb6d548732248c3234f7b92736c2ab7a7aae65294b87a7fd64c06" # amd-gfx942
+          "c1ba3bfe84217fd67df3dd1f8b67c80a7f7b33d0ad4d74b41d6567036e032ace" # amd-gfx950
+      )
+      set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
       include(ExternalProject)
-      ExternalProject_Add(aotriton_external
-        LIST_SEPARATOR ","
-        SOURCE_DIR ${TE}/3rdparty/aotriton
-        CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-        -DAOTRITON_TARGET_ARCH=${ARCH_LIST_COMMA_STR}
-        -DGPU_TARGETS=${ARCH_LIST_COMMA_STR}
-        -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-        -DAOTRITON_NO_PYTHON=ON
-        -DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-        -DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE}
-        BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"
+
+      # Download GPU kernels for a specific target
+      function(aotriton_download_image image project)
+        list(FIND __AOTRITON_IMAGE_LIST ${image} index)
+        list(GET __AOTRITON_IMAGE_SHA256_LIST ${index} __AOTRITON_IMAGE_SHA256)
+
+        string(CONCAT __AOTRITON_FILE
+               "aotriton-${__AOTRITON_VER}-images-"
+               "${image}.tar.gz")
+        string(CONCAT __AOTRITON_URL
+               "${__AOTRITON_BASE_URL}"
+               "${__AOTRITON_VER}/${__AOTRITON_FILE}")
+
+        # Set up directories
+        set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton/download/${image})
+        set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton/image/${image})
+
+        ExternalProject_Add(${project}
+          URL "${__AOTRITON_URL}"
+          URL_HASH SHA256=${__AOTRITON_IMAGE_SHA256}
+          DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
+          SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
+          CONFIGURE_COMMAND ""
+          BUILD_COMMAND ""
+          INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
+            "${__AOTRITON_EXTRACT_DIR}"
+            "${__AOTRITON_INSTALL_DIR}"
+          BUILD_BYPRODUCTS
+            "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
+        )
+        message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.")
+      endfunction()
+
+      # Build the AOTriton runtime from source with custom suffix to avoid
+      # potential conflict with libaotriton as provided by PyTorch
+      function(aotriton_build_from_source noimage)
+        message(STATUS "No-image mode: ${noimage}.")
+        ExternalProject_Add(aotriton_external
+          LIST_SEPARATOR ","
+          SOURCE_DIR ${TE}/3rdparty/aotriton
+          CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
+          -DAOTRITON_TARGET_ARCH=${ARCH_LIST_COMMA_STR}
+          -DGPU_TARGETS=${ARCH_LIST_COMMA_STR}
+          -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
+          -DAOTRITON_NO_PYTHON=ON
+          -DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
+          -DAOTRITON_NOIMAGE_MODE=${noimage}
+          BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"
      )
+        message(STATUS "Adding AOTriton library.")
+        add_dependencies(aotriton aotriton_external)
+        target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
+        target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
+        install(DIRECTORY
+          ${__AOTRITON_INSTALL_DIR}/lib
+          DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
+          PATTERN "cmake" EXCLUDE
+        )
+      endfunction()
+
       add_library(aotriton INTERFACE)
-      add_dependencies(aotriton aotriton_external)
-      target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
-      target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
+      message(STATUS "Building AOTriton from source.")
+      string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
+      aotriton_build_from_source(${AOTRITON_NOIMAGE_MODE})
+
+      # Download GPU kernels if needed
       if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
-        set(__AOTRITON_VER "0.10b")
-        set(__AOTRITON_SHA256 "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b")
-        string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/"
-               "${__AOTRITON_VER}/aotriton-"
-               "${__AOTRITON_VER}-manylinux_2_28"
-               "_x86_64-rocm7.0"
-               "-shared.tar.gz")
-        set(aotriton_image_dirs)
-        foreach(X IN LISTS CMAKE_HIP_ARCHITECTURES)
-          list(APPEND aotriton_image_dirs "${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball/lib/aotriton.images/amd-${X}")
+        message(STATUS "Downloading AOTriton GPU Kernels.")
+        foreach(image ${__AOTRITON_IMAGE_LIST})
+          string(SUBSTRING ${image} 7 -1 gfx_pattern)
+          string(REPLACE "x" "." gfx_regex ${gfx_pattern})
+          foreach(target ${ARCH_LIST_COMMA_STR})
+            if(target MATCHES ${gfx_regex})
+              message(STATUS "Downloading AOTriton image ${image}.")
+              set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern})
+              aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET})
+              add_dependencies(aotriton_external ${__AOTRITON_DOWNLOAD_TARGET})
+              break()
+            endif()
+          endforeach()
         endforeach()
-        set(aotriton_lib_install_dir "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images")
-        file(REMOVE_RECURSE ${aotriton_lib_install_dir})
-        file(MAKE_DIRECTORY ${aotriton_lib_install_dir})
-        ExternalProject_Add(aotriton_images
-          URL "${__AOTRITON_URL}"
-          URL_HASH SHA256=${__AOTRITON_SHA256}
-          SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball
-          CONFIGURE_COMMAND ""
-          BUILD_COMMAND ""
-          BUILD_ALWAYS TRUE
-          INSTALL_COMMAND cp -Ra ${aotriton_image_dirs} ${aotriton_lib_install_dir})
-        add_dependencies(aotriton aotriton_images)
+      else()
      endif()
-      install(DIRECTORY
-        ${__AOTRITON_INSTALL_DIR}/lib
-        DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
-        PATTERN "cmake" EXCLUDE
-        PATTERN "libaotriton${__AOTRITON_SUFFIX}_v2.so" EXCLUDE)
+
     else()
       # Use aotriton built during initial TE building/installation
       # When only need rebuild TE library itself
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
index 66fa72c0c..392a73bfa 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
@@ -401,7 +401,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   } else if(fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_AOTriton){
     fused_attn_aotriton_fwd_qkvpacked(
       b, h, max_seqlen, d,
-      is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+      is_training, attn_scale, dropout,
+      window_size_left, window_size_right,
+      qkv_layout, bias_type, attn_mask_type,
       input_QKV, output_O,
       Aux_CTX_Tensors,
       input_cu_seqlens,
@@ -576,6 +578,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
     fused_attn_aotriton_fwd_kvpacked(
       b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
       is_training, attn_scale, dropout,
+      window_size_left, window_size_right,
       qkv_layout, bias_type, attn_mask_type,
       input_Q, input_KV, output_O,
       Aux_CTX_Tensors,
@@ -759,6 +762,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     fused_attn_aotriton_fwd(
       b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk,
       is_training, attn_scale, dropout,
+      window_size_left, window_size_right,
       qkv_layout, bias_type, attn_mask_type,
       input_Q, input_K, input_V, output_O,
       Aux_CTX_Tensors,
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
index 3fe7ec854..3dbc68902 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
@@ -37,6 +37,19 @@ inline aotriton::TensorView<0> mk_aoscalartensor(const uint64_t* ptr)
 namespace transformer_engine {
 namespace fused_attn_rocm {

+// TODO: Support SWA
+std::tuple<int32_t, int32_t> get_window_sizes(
+  int32_t window_size_left,
+  int32_t window_size_right,
+  bool is_causal
+){
+  using aotriton::v3::flash::WindowValue;
+  if(is_causal){
+    return {WindowValue::TopLeftAligned, WindowValue::TopLeftAligned};
+  }
+  return {-1, -1};
+}
+
 // check the fused attn config to see whether it's aotriton backend supported
 bool is_aotriton_backend_supported(
   NVTEDType q_dtype,
@@ -127,12 +140,12 @@ aotriton::DType nvte_to_aotriton_dtype(DType t_dtype){
 void fused_attn_aotriton_fwd_impl(
   uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d,
   bool is_training, float scaling_factor, float dropout_probability,
-  NVTE_QKV_Layout layout,
+  int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout,
   NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
   void *devPtrQ, void *devPtrK, void *devPtrV,
   void *devPtrSoftmaxAux, void *devPtrO,
   const uint64_t* devPtrDropoutSeed, const uint64_t* devPtrDropoutOffset,
-  //void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
+  void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
   aotriton::DType dtype,
   void *workspace,
   size_t *workspace_size,
@@ -163,6 +176,11 @@ void fused_attn_aotriton_fwd_impl(
   auto k_tensor = aotriton::TensorView<4>(reinterpret_cast<intptr_t>(devPtrK), kv_shape, k_stride, dtype);
   auto v_tensor = aotriton::TensorView<4>(reinterpret_cast<intptr_t>(devPtrV), kv_shape, v_stride, dtype);
+  // Cumulative seqlen tensors
+  std::array<uint64_t, 1> cu_seqlens_shape{b+1};
+  std::array<uint64_t, 1> cu_seqlens_stride{1};
+  auto cu_seqlens_q = aotriton::TensorView<1>(reinterpret_cast<intptr_t>(devPtrCuSeqlensQ), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32);
+  auto cu_seqlens_k = aotriton::TensorView<1>(reinterpret_cast<intptr_t>(devPtrCuSeqlensKV), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32);

   std::array<uint64_t, 4> o_stride;
   generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
@@ -185,6 +203,22 @@ void fused_attn_aotriton_fwd_impl(
     if (env_p != nullptr && std::string(env_p) == "1")
       nvte_log_aotriton_config = true;
   }
+  aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
+  using aotriton::v3::flash::attn_fwd;
+  auto seed = mk_aoscalartensor(devPtrDropoutSeed);
+  auto offset1 = mk_aoscalartensor(devPtrDropoutOffset);
+  auto seed_output = mk_aoscalartensor(nullptr);
+  auto offset_output = mk_aoscalartensor(nullptr);
+  const auto is_causal = mask_type == NVTE_CAUSAL_MASK;
+  aotriton::TensorView<0> atomic_for_causal(reinterpret_cast<intptr_t>(workspace), aotriton::DType::kInt32);
+
+  using aotriton::v3::flash::VarlenType;
+  int8_t varlen_type = VarlenType::None;
+
+  auto [window_left, window_right] = get_window_sizes(window_size_left, window_size_right, is_causal);
+  using aotriton::v3::flash::CausalType;
+  int8_t causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
+
   if (nvte_log_aotriton_config) {
     std::cout<< ...   // (unchanged logging of the fwd attention config; elided)
   }
-  aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
-  using aotriton::v2::flash::attn_fwd;
-  auto seed = mk_aoscalartensor(devPtrDropoutSeed);
-  auto offset1 = mk_aoscalartensor(devPtrDropoutOffset);
-  auto offset2 = 0;
-  auto seed_output = mk_aoscalartensor(nullptr);
-  auto offset_output = mk_aoscalartensor(nullptr);
-  const auto is_causal = mask_type == NVTE_CAUSAL_MASK;
-  aotriton::TensorView<0> atomic_for_causal(reinterpret_cast<intptr_t>(workspace), aotriton::DType::kInt32);
+
+  aotriton::v3::flash::attn_fwd_params fwd_params{};
+  fwd_params.Q = q_tensor;
+  fwd_params.K = k_tensor;
+  fwd_params.V = v_tensor;
+  // fwd_params.B = empty_bias;
+  // fwd_params.A = nullptr; // Alibi slopes, currently unused
+  fwd_params.Sm_scale = scaling_factor;
+  fwd_params.L = M_tensor;
+  fwd_params.Out = o_tensor;
+  if(varlen_type){
+    fwd_params.cu_seqlens_q = cu_seqlens_q;
+    fwd_params.cu_seqlens_k = cu_seqlens_k;
+    fwd_params.Max_seqlen_q = s_q; // Unused if cu_seqlens_q is empty
+    fwd_params.Max_seqlen_k = s_kv; // Unused if cu_seqlens_k is empty
+  }
+  fwd_params.dropout_p = is_training? dropout_probability : 0;
+  fwd_params.philox_seed_ptr = seed;
+  fwd_params.philox_offset1 = offset1;
+  fwd_params.philox_offset2 = 0;
+  fwd_params.philox_seed_output = seed_output;
+  fwd_params.philox_offset_output = offset_output;
+  fwd_params.encoded_softmax = encoded_softmax_tensor;
+  fwd_params.persistent_atomic_counter = atomic_for_causal;
+  fwd_params.causal_type = causal_type;
+  fwd_params.varlen_type = varlen_type;
+  fwd_params.window_left = window_left;
+  fwd_params.window_right = window_right;
+
   NVTE_CHECK_CUDA(hipMemsetAsync(workspace, 0, sizeof(int32_t), stream));
-  NVTE_CHECK_CUDA(attn_fwd(q_tensor,
-                           k_tensor,
-                           v_tensor,
-                           empty_bias,
-                           scaling_factor,
-                           M_tensor,
-                           o_tensor,
-                           is_training? dropout_probability : 0,
-                           seed,
-                           offset1,
-                           offset2,
-                           seed_output,
-                           offset_output,
-                           encoded_softmax_tensor,
-                           is_causal,
-                           atomic_for_causal,
-                           stream));
+  NVTE_CHECK_CUDA(attn_fwd(fwd_params, fwd_params.kVersion, stream));
 }

 void fused_attn_aotriton_bwd_impl(
@@ -343,6 +405,7 @@ using namespace transformer_engine::fused_attn_rocm;
 void fused_attn_aotriton_fwd_qkvpacked(
   size_t b, size_t h, size_t max_seqlen, size_t d,
   bool is_training, float attn_scale, float dropout,
+  int32_t window_left, int32_t window_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_QKV, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors,
@@ -394,12 +457,14 @@ void fused_attn_aotriton_fwd_qkvpacked(
   fused_attn_aotriton_fwd_impl(
     b, h, h, max_seqlen, max_seqlen, d,
     is_training, attn_scale, dropout,
+    window_left, window_right,
     qkv_layout, bias_type, attn_mask_type,
     devPtrQ, devPtrK, devPtrV,
     devPtrS,
     devPtrO,
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr),
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr) + 1,
+    input_cu_seqlens->data.dptr, input_cu_seqlens->data.dptr,
     nvte_to_aotriton_dtype(QKV_type),
     workspace->data.dptr,
     &workspace_size,
@@ -460,7 +525,6 @@ void fused_attn_aotriton_bwd_qkvpacked(
   void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);

   size_t workspace_size = 0;
-
   fused_attn_aotriton_bwd_impl(
     b, h, h, max_seqlen, max_seqlen, d,
     attn_scale, dropout,
@@ -498,6 +562,7 @@ void fused_attn_aotriton_bwd_qkvpacked(
 void fused_attn_aotriton_fwd_kvpacked(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   bool is_training, float attn_scale, float dropout,
+  int32_t window_left, int32_t window_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_KV, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors,
@@ -549,12 +614,14 @@ void fused_attn_aotriton_fwd_kvpacked(
   fused_attn_aotriton_fwd_impl(
     b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
     is_training, attn_scale, dropout,
+    window_left, window_right,
     qkv_layout, bias_type, attn_mask_type,
     devPtrQ, devPtrK, devPtrV,
     devPtrS,
     devPtrO,
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr),
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr) + 1,
+    input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
     nvte_to_aotriton_dtype(QKV_type),
     workspace->data.dptr,
     &workspace_size,
@@ -618,7 +685,6 @@ void fused_attn_aotriton_bwd_kvpacked(
   void *devPtrSoftmaxStats = output_S->data.dptr;

   size_t workspace_size = 0;
-
   fused_attn_aotriton_bwd_impl(
     b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
     attn_scale, dropout,
@@ -656,6 +722,7 @@ void fused_attn_aotriton_bwd_kvpacked(
 void fused_attn_aotriton_fwd(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   bool is_training, float attn_scale, float dropout,
+  int32_t window_left, int32_t window_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors,
@@ -698,12 +765,14 @@ void fused_attn_aotriton_fwd(
   fused_attn_aotriton_fwd_impl(
     b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
     is_training, attn_scale, dropout,
+    window_left, window_right,
     qkv_layout, bias_type, attn_mask_type,
-    devPtrQ, devPtrK, devPtrV, 
+    devPtrQ, devPtrK, devPtrV,
     devPtrS,
     devPtrO,
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr),
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr) + 1,
+    input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
     nvte_to_aotriton_dtype(QKV_type),
     workspace->data.dptr,
     &workspace_size,
@@ -755,7 +824,6 @@ void fused_attn_aotriton_bwd(
   void *devPtrSoftmaxStats = output_S->data.dptr;

   size_t workspace_size = 0;
-
   fused_attn_aotriton_bwd_impl(
     b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
     attn_scale, dropout,
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h
index 5b9ef89a1..b016acc67 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h
@@ -35,6 +35,7 @@ bool is_aotriton_backend_supported(
 void fused_attn_aotriton_fwd_qkvpacked(
   size_t b, size_t h, size_t max_seqlen, size_t d,
   bool is_training, float attn_scale, float dropout,
+  int32_t window_left, int32_t window_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_QKV, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors,
@@ -58,6 +59,7 @@ void fused_attn_aotriton_bwd_qkvpacked(
 void fused_attn_aotriton_fwd_kvpacked(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   bool is_training, float attn_scale, float dropout,
+  int32_t window_left, int32_t window_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_KV, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors,
@@ -83,6 +85,7 @@ void fused_attn_aotriton_bwd_kvpacked(
 void fused_attn_aotriton_fwd(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   bool is_training, float attn_scale, float dropout,
+  int32_t window_left, int32_t window_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors,
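
Note on the calling convention adopted above: the patch replaces aotriton::v2's long positional attn_fwd(...) call with aotriton::v3's parameter-struct call, attn_fwd(fwd_params, fwd_params.kVersion, stream). Below is a minimal stand-alone sketch of that pattern with hypothetical names (not AOTriton's actual headers), assuming the version constant is baked into the struct so the callee can tell which fields the caller was compiled against:

// Hedged sketch, not AOTriton's real API: the "params struct + version" pattern.
// New fields can be appended without breaking existing call sites.
#include <cstdint>
#include <iostream>

struct example_attn_params {
  static constexpr int32_t kVersion = 3;   // bumped whenever fields are appended
  float sm_scale = 1.0f;
  float dropout_p = 0.0f;
  int32_t window_left = -1;                // -1 == unbounded, mirroring the patch
  int32_t window_right = -1;
  int8_t causal_type = 0;                  // 0 == none, non-zero == causal/windowed
};

// The callee only reads fields that exist in the version the caller passed.
int example_attn_fwd(const example_attn_params& p, int32_t caller_version) {
  std::cout << "version " << caller_version
            << ", window (" << p.window_left << ", " << p.window_right << ")\n";
  return 0;  // status code, standing in for the real API's error type
}

int main() {
  example_attn_params p{};
  p.causal_type = 1;  // causal run; windows keep their defaults here
  return example_attn_fwd(p, p.kVersion);
}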
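
Note on the window-size plumbing: get_window_sizes() currently only distinguishes causal from full attention (hence the "TODO: Support SWA"). The stand-alone sketch below illustrates the assumed semantics of a (window_left, window_right) pair, where -1 means unbounded and plain causal attention is the special case of an unbounded left window with a right window of 0, aligned to the top-left of the score matrix:

// Hedged illustration, no AOTriton headers: which key positions j a query
// position i may attend to for a given (window_left, window_right).
#include <cstdint>
#include <iostream>

// Allowed when  i - window_left <= j <= i + window_right  (-1 == no bound).
bool allowed(int64_t i, int64_t j, int32_t window_left, int32_t window_right) {
  if (window_left >= 0 && j < i - window_left) return false;
  if (window_right >= 0 && j > i + window_right) return false;
  return true;
}

void print_mask(int32_t wl, int32_t wr, int64_t s = 6) {
  std::cout << "window_left=" << wl << " window_right=" << wr << "\n";
  for (int64_t i = 0; i < s; ++i) {
    for (int64_t j = 0; j < s; ++j) std::cout << (allowed(i, j, wl, wr) ? 'x' : '.');
    std::cout << "\n";
  }
}

int main() {
  print_mask(-1, -1);  // full attention: the non-causal path in the patch
  print_mask(-1, 0);   // causal, top-left aligned: what NVTE_CAUSAL_MASK maps to
  print_mask(2, 0);    // a sliding window, left for the SWA TODO
}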