From e90b99164b037c0f5439dfe37945e69f54e76c98 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Thu, 16 Oct 2025 15:54:11 +0000 Subject: [PATCH 01/30] [ROCm] manually pick up fwd native padding support from Meekail's PR --- .../include/ck_fused_attn/ck_fused_attn.hpp | 4 + .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 50 +++++++---- .../common/fused_attn_rocm/fused_attn_ck.cpp | 82 +++++++++++++++---- 3 files changed, 104 insertions(+), 32 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index d174a48f4..1eef86309 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -59,6 +59,7 @@ hipError_t ck_attn_fwd( uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, void* lse_ptr, bool uses_fwd_v3, + int how_v3_bf16_cvt, hipStream_t stream); hipError_t ck_attn_varlen_fwd( @@ -72,6 +73,7 @@ hipError_t ck_attn_varlen_fwd( const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, bool is_training, float scaling_factor, float dropout_probability, @@ -82,6 +84,8 @@ hipError_t ck_attn_varlen_fwd( uint64_t stride_h_o, uint64_t stride_s_o, void* lse_thd_ptr, bool uses_fwd_v3, + int how_v3_bf16_cvt, + bool is_v3_api_check, hipStream_t stream); hipError_t ck_attn_bwd( diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index c87a3db6c..fc2f14f06 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -27,6 +27,9 @@ void log_fwd_config(const char* func_name, const bool is_v_rowmajor, const bool do_fp8_static_quant, const bool uses_fwd_v3, + const bool how_v3_bf16_cvt, + const void* cu_seqlen_q_padded_ptr, + const void* cu_seqlen_kv_padded_ptr, const fmha_fwd_args& fmha_args){ bool ck_fused_attn_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { @@ -50,6 +53,9 @@ void log_fwd_config(const char* func_name, std::cout<<"has_dropout: "< 0.f); @@ -258,7 +265,7 @@ hipError_t ck_attn_fwd( }(); // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, fmha_args); + log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, nullptr, nullptr, fmha_args); if (uses_fwd_v3) { set_aiter_asm_dir(); @@ -271,7 +278,11 @@ hipError_t ck_attn_fwd( mask_type, bias_type, has_lse, - uses_fwd_v3); + uses_fwd_v3, + how_v3_bf16_cvt, + nullptr, //cu_seqlen_q_padded + nullptr, //cu_seqlen_kv_padded + false); if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); @@ -290,6 +301,7 @@ hipError_t ck_attn_varlen_fwd( const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, bool is_training, float 
scaling_factor, float dropout_probability, @@ -300,6 +312,8 @@ hipError_t ck_attn_varlen_fwd( uint64_t stride_h_o, uint64_t stride_s_o, void* lse_thd_ptr, bool uses_fwd_v3, + int how_v3_bf16_cvt, + bool is_v3_api_check, hipStream_t stream){ bool has_dropout = (is_training && dropout_probability > 0.f); @@ -389,8 +403,8 @@ hipError_t ck_attn_varlen_fwd( cu_seqlen_q_ptr, //seqstart_q_ptr cu_seqlen_kv_ptr, //seqstart_k_ptr nullptr, //seqlen_k_ptr - nullptr, //seqstart_padded_q_ptr - nullptr, //seqstart_padded_k_ptr + cu_seqlen_q_padded_ptr, //seqstart_padded_q_ptr + cu_seqlen_kv_padded_ptr, //seqstart_padded_k_ptr max_seqlen_q, //seqlen_q, unused in group mode max_seqlen_kv, //seqlen_kv, unused in group mode batch, @@ -433,21 +447,29 @@ hipError_t ck_attn_varlen_fwd( }(); // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, fmha_args); + log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, fmha_args); if (uses_fwd_v3) { set_aiter_asm_dir(); } - float average_runtime = aiter::mha_fwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_lse, - uses_fwd_v3); - if(average_runtime < 0){ + float average_runtime_or_v3_check_status = aiter::mha_fwd( + fmha_args, + stream_config, + data_type_str, + is_group_mode, + mask_type, + bias_type, + has_lse, + uses_fwd_v3, + how_v3_bf16_cvt, + uses_fwd_v3? cu_seqlen_q_padded_ptr: nullptr, + uses_fwd_v3? cu_seqlen_kv_padded_ptr: nullptr, + is_v3_api_check); + if(is_v3_api_check){ + return (hipError_t)(average_runtime_or_v3_check_status > 0); + } + if(average_runtime_or_v3_check_status < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); } diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index b38249f5b..b6a05e9a8 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -536,7 +536,7 @@ void remove_padding_softmax_lse( // actual fwd implementation, calling ck api directly void fused_attn_ck_fwd_impl( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, - bool pad_between_seqs, size_t max_tokens_q, size_t max_tokens_kv, + size_t max_tokens_q, size_t max_tokens_kv, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -556,10 +556,18 @@ void fused_attn_ck_fwd_impl( if (env_p != nullptr && std::string(env_p) == "1") nvte_log_ck_config = true; } + bool nvte_ck_uses_fwd_v3 = getenv("NVTE_CK_USES_FWD_V3", 0); + int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1); bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; + bool is_batch = (nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD || + nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD); + bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); 
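Context for the two offset arrays this is_batch && is_padding path relies on: in a padded batch layout (BSHD/SBHD under a padding mask), each sequence owns a fixed slot of s tokens but only seqlen[i] of them are valid. The varlen code therefore carries cu_seqlen, the prefix sum of the valid lengths, alongside cu_seqlen_padded, the prefix sum of the slot sizes, which is what the generate_cu_seqlen_padded kernel introduced in PATCH 02 below fills with s * i. A minimal host-side sketch of the relationship (illustrative only, not part of the patch; the names mirror the device code):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int32_t b = 2, s = 4;                  // two sequences, fixed slot of 4 tokens
  const std::vector<int32_t> seqlen = {3, 2};  // valid tokens per sequence

  std::vector<int32_t> cu_seqlen(b + 1, 0), cu_seqlen_padded(b + 1, 0);
  for (int32_t i = 0; i < b; ++i) {
    cu_seqlen[i + 1] = cu_seqlen[i] + seqlen[i];  // prefix sum of valid lengths
    cu_seqlen_padded[i + 1] = s * (i + 1);        // prefix sum of slot sizes, as in the kernel
  }
  // Prints cu_seqlen = {0, 3, 5} and cu_seqlen_padded = {0, 4, 8}: flattened
  // token slots 3, 6 and 7 hold padding and must be skipped by the kernels.
  for (int32_t i = 0; i <= b; ++i)
    std::printf("%d %d\n", cu_seqlen[i], cu_seqlen_padded[i]);
  return 0;
}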
+ // extract the qkv and o storage bytes to allocate buffer for padding removing // b from cu_seqlen is not the actual storage batch for pad_between_seqs case size_t q_storage_bytes = max_tokens_q*h*d_qk*nvte_dtype_size(dtype); @@ -573,7 +581,7 @@ void fused_attn_ck_fwd_impl( if(bias_type == NVTE_Bias_Type::NVTE_ALIBI){ (*workspace_size)+= h*sizeof(float); } - if(pad_between_seqs){ + if(is_batch && is_padding){ // softmax_lse buffer (*workspace_size)+= max_tokens_q*h*sizeof(float); // request q, k, v, o buffer without padding @@ -621,7 +629,7 @@ void fused_attn_ck_fwd_impl( void* devPtrKWithoutPadding = nullptr; void* devPtrVWithoutPadding = nullptr; void* devPtrOWithoutPadding = nullptr; - if(pad_between_seqs){ + if(is_batch && is_padding){ // next h*max_tokens_q*sizeof(float) in workspace are for lse buffer devPtrSoftmaxLSEWithoutPadding = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); @@ -673,7 +681,8 @@ void fused_attn_ck_fwd_impl( std::cout<<"v_stride: ("< Date: Thu, 16 Oct 2025 16:09:04 -0500 Subject: [PATCH 02/30] Initial update --- .../common/fused_attn_rocm/fused_attn_ck.cpp | 97 +++++++++++++------ 1 file changed, 69 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index b6a05e9a8..7e9db8d0a 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -9,6 +9,7 @@ #include // Required for std::accumulate #ifdef USE_FUSED_ATTN_CK #include +#include "ck_tile/host.hpp" #endif // USE_FUSED_ATTN_CK #include "../util/cuda_runtime.h" #include "../util/system.h" @@ -215,6 +216,18 @@ ck_fused_attn::MaskType set_ck_mask(NVTE_Mask_Type nvte_mask_type, int64_t nvte_ return ck_fused_attn::MaskType::window_generic; } +__global__ +void generate_cu_seqlen_padded( + uint32_t s_q, uint32_t s_kv, uint32_t b, + ck_tile::index_t* cu_seqlen_q_padded_ptr, + ck_tile::index_t* cu_seqlen_kv_padded_ptr +){ + for(int i = blockIdx.x * blockDim.x + threadIdx.x; i < b+1; i += blockDim.x * gridDim.x){ + cu_seqlen_q_padded_ptr[i] = s_q * i; + cu_seqlen_kv_padded_ptr[i] = s_kv * i; + } +} + __global__ void generate_alibi_slope(uint64_t h, float* alibi_slope_ptr){ for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < h; id += blockDim.x * gridDim.x){ @@ -561,9 +574,10 @@ void fused_attn_ck_fwd_impl( int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1); bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; + bool is_SBHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD; + bool is_BSHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD + bool is_batch = is_BSHD || is_SBHD; - bool is_batch = (nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD || - nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD); bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); @@ -582,10 +596,14 @@ void fused_attn_ck_fwd_impl( (*workspace_size)+= h*sizeof(float); } if(is_batch && is_padding){ + // cu_seqlen_padded buffers + (*workspace_size)+= 2*(b+1)*sizeof(float); // softmax_lse buffer (*workspace_size)+= max_tokens_q*h*sizeof(float); - // request q, k, v, o buffer without padding - (*workspace_size)+= q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes; + if(is_SBHD){ + 
// request q, k, v, o buffer without padding + (*workspace_size)+= q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes; + } }else if(is_ragged){ // We include a softmax_lse buffer to use the kernel in order to properly reshape the lse as needed. (*workspace_size)+= max_tokens_q*h*sizeof(float); @@ -633,28 +651,36 @@ void fused_attn_ck_fwd_impl( // next h*max_tokens_q*sizeof(float) in workspace are for lse buffer devPtrSoftmaxLSEWithoutPadding = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); - //determine q, k ,v buffer based on the workspace next ptr and layout group - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); - //Q ptr always comes at first - devPtrQWithoutPadding = workspace_next; - if(layout_group==NVTE_QKV_Layout_Group::NVTE_3HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_H3D){ - //keep the start address difference the same among q, k, and v - devPtrKWithoutPadding = static_cast(static_cast(devPtrQWithoutPadding) + (static_cast(devPtrK) - static_cast(devPtrQ))); - devPtrVWithoutPadding = static_cast(static_cast(devPtrQWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrQ))); - workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes + k_storage_bytes + v_storage_bytes); - }else if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){ - workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes); - //keep the start address difference the same between k and v - devPtrKWithoutPadding = workspace_next; - devPtrVWithoutPadding = static_cast(static_cast(devPtrKWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrK))); - workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes + v_storage_bytes); + if(is_SBHD){ + //determine q, k ,v buffer based on the workspace next ptr and layout group + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + //Q ptr always comes at first + devPtrQWithoutPadding = workspace_next; + if(layout_group==NVTE_QKV_Layout_Group::NVTE_3HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_H3D){ + //keep the start address difference the same among q, k, and v + devPtrKWithoutPadding = static_cast(static_cast(devPtrQWithoutPadding) + (static_cast(devPtrK) - static_cast(devPtrQ))); + devPtrVWithoutPadding = static_cast(static_cast(devPtrQWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrQ))); + workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes + k_storage_bytes + v_storage_bytes); + }else if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){ + workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes); + //keep the start address difference the same between k and v + devPtrKWithoutPadding = workspace_next; + devPtrVWithoutPadding = static_cast(static_cast(devPtrKWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrK))); + workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes + v_storage_bytes); + }else{ + //qkv separated + workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes); + devPtrKWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes); + devPtrVWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + v_storage_bytes); + } }else{ - //qkv separated - workspace_next = 
static_cast(static_cast(workspace_next) + q_storage_bytes); - devPtrKWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes); - devPtrVWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + v_storage_bytes); + // cu_seqlen_padded ptrs for THD conversion + devPtrSeqOffsetsQ = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(float)); + devPtrSeqOffsetsKV = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(float)); } //determine the o buffer based on workspace next section devPtrOWithoutPadding = workspace_next; @@ -700,7 +726,22 @@ void fused_attn_ck_fwd_impl( std::cout<<"window_size: ("<>>( + s_q, s_kv, b, + static_cast(devPtrSeqOffsetsQ), + static_cast(devPtrSeqOffsetsKV) + ); + is_ragged=true; + if(nvte_log_ck_config){ + std::cout << "\nConverting BSHD to THD\n"; + } + } + if(is_SBHD && is_padding){ // remove padding for q, k, v remove_padding(dtype, b, h, s_q, d_qk, max_tokens_q, false, q_stride[0], q_stride[1], q_stride[2], devPtrQ, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrQWithoutPadding, stream); remove_padding(dtype, b, hg, s_kv, d_qk, max_tokens_kv, false, k_stride[0], k_stride[1], k_stride[2], devPtrK, devPtrCuSeqlensKV, devPtrSeqOffsetsKV, devPtrKWithoutPadding, stream); @@ -726,7 +767,7 @@ void fused_attn_ck_fwd_impl( set_ck_mask(mask_type, window_size_left, window_size_right), window_size_left, window_size_right, devPtrOWithoutPadding, - o_stride[1], (is_ragged? o_stride[2] : std::min(o_stride[0], o_stride[2])), + o_stride[1], std::min(o_stride[0], o_stride[2]), devPtrSoftmaxLSEWithoutPadding, nvte_ck_uses_fwd_v3, nvte_ck_how_v3_bf16_cvt, @@ -785,7 +826,7 @@ void fused_attn_ck_fwd_impl( nvte_ck_how_v3_bf16_cvt, false, stream)); - if(is_v3_supported){ + if(nvte_ck_uses_fwd_v3 && is_v3_supported){ // aiter asm output softmax_lse with padding add_padding_softmax_lse(b, h, s_q, max_tokens_q, true, devPtrSoftmaxLSEWithoutPadding, devPtrSeqOffsetsQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream); }else{ From 81bac3520d8a45ec2a4a9496a3d91c244d2838d8 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 16 Oct 2025 16:19:27 -0500 Subject: [PATCH 03/30] Updated stride --- .../common/fused_attn_rocm/fused_attn_ck.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 7e9db8d0a..f0f289e67 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -575,7 +575,7 @@ void fused_attn_ck_fwd_impl( bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; bool is_SBHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD; - bool is_BSHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD + bool is_BSHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD; bool is_batch = is_BSHD || is_SBHD; bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || @@ -755,11 +755,11 @@ void fused_attn_ck_fwd_impl( b, h, hg, s_q, s_kv, d_qk, d_v, max_tokens_q, devPtrQWithoutPadding, - q_stride[1], std::min(q_stride[0], q_stride[2]), + q_stride[1], q_stride[0], devPtrKWithoutPadding, - k_stride[1], std::min(k_stride[0], k_stride[2]), + k_stride[1], k_stride[0], devPtrVWithoutPadding, - v_stride[1], std::min(v_stride[0], v_stride[2]), + v_stride[1], v_stride[0], 
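(The ck_attn_varlen_fwd argument list resumes with the cu_seqlen pointers directly below.) For context on how the padded offsets passed here get consumed later in the series: the THD reduction kernels reworked in PATCH 06 map each flattened token index back to its sequence with a binary search over cu_seqlen_padded, then skip indices that land in a padding slot. A host-side sketch of that lookup, mirroring the device helper added in PATCH 06 (illustrative only, not part of the patch):

#include <cstdint>
#include <cstdio>

// Same contract as the device-side binary_search from PATCH 06: in an increasing
// array of length len, return index such that array[index] <= target < array[index+1].
static int binary_search(int32_t target, const int32_t* array, uint64_t len) {
  int left = 1, right = static_cast<int>(len) - 1;
  while (left < right) {
    int mid = (left + right) / 2;
    if (array[mid] <= target) left = mid + 1;
    else right = mid;
  }
  return left - 1;
}

int main() {
  const int32_t b = 2;
  const int32_t cu_seqlen[] = {0, 3, 5};         // valid lengths {3, 2}
  const int32_t cu_seqlen_padded[] = {0, 4, 8};  // fixed slot size 4

  for (int32_t tok = 0; tok < cu_seqlen_padded[b]; ++tok) {
    const int seq = binary_search(tok, cu_seqlen_padded, b + 1);
    const int32_t valid = cu_seqlen[seq + 1] - cu_seqlen[seq];
    // Same skip test the dk_dv_reduce_thd kernels apply to padded slots.
    const bool is_pad = tok >= cu_seqlen_padded[seq] + valid;
    std::printf("token %2d -> seq %d %s\n", tok, seq, is_pad ? "(padding)" : "(valid)");
  }
  return 0;
}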
        devPtrCuSeqlensQ, devPtrCuSeqlensKV,
        nullptr, nullptr, //cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr
        is_training, scaling_factor, dropout_probability,
@@ -767,7 +767,7 @@ void fused_attn_ck_fwd_impl(
        set_ck_mask(mask_type, window_size_left, window_size_right),
        window_size_left, window_size_right,
        devPtrOWithoutPadding,
-       o_stride[1], std::min(o_stride[0], o_stride[2]),
+       o_stride[1], o_stride[0],
        devPtrSoftmaxLSEWithoutPadding,
        nvte_ck_uses_fwd_v3,
        nvte_ck_how_v3_bf16_cvt,

From 54ee86af6052ee73d3414066e068f808c9908d83 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Thu, 16 Oct 2025 16:50:59 -0500
Subject: [PATCH 04/30] Corrected typing in allocation portions

---
 .../common/fused_attn_rocm/fused_attn_ck.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
index f0f289e67..ba898d6e0 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@@ -596,13 +596,14 @@ void fused_attn_ck_fwd_impl(
     (*workspace_size)+= h*sizeof(float);
   }
   if(is_batch && is_padding){
-    // cu_seqlen_padded buffers
-    (*workspace_size)+= 2*(b+1)*sizeof(float);
     // softmax_lse buffer
     (*workspace_size)+= max_tokens_q*h*sizeof(float);
     if(is_SBHD){
       // request q, k, v, o buffer without padding
       (*workspace_size)+= q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes;
+    }else{
+      // cu_seqlen_padded buffers
+      (*workspace_size)+= 2*(b+1)*sizeof(ck_tile::index_t);
     }
   }else if(is_ragged){
     // We include a softmax_lse buffer to use the kernel in order to properly reshape the lse as needed.
     (*workspace_size)+= max_tokens_q*h*sizeof(float);
@@ -678,9 +679,9 @@ void fused_attn_ck_fwd_impl(
     }else{
       // cu_seqlen_padded ptrs for THD conversion
       devPtrSeqOffsetsQ = workspace_next;
-      workspace_next = static_cast<void*>(static_cast<char*>(workspace_next) + (b+1)*sizeof(float));
+      workspace_next = static_cast<void*>(static_cast<char*>(workspace_next) + (b+1)*sizeof(ck_tile::index_t));
       devPtrSeqOffsetsKV = workspace_next;
-      workspace_next = static_cast<void*>(static_cast<char*>(workspace_next) + (b+1)*sizeof(float));
+      workspace_next = static_cast<void*>(static_cast<char*>(workspace_next) + (b+1)*sizeof(ck_tile::index_t));
     }
     //determine the o buffer based on workspace next section
     devPtrOWithoutPadding = workspace_next;

From 47a7cab92bd0c8da1756bc7de307353d3b75befc Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Fri, 17 Oct 2025 10:20:17 -0500
Subject: [PATCH 05/30] Applied Ye's patch

---
 .../common/fused_attn_rocm/fused_attn_ck.cpp | 22 +++++++++----------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
index ba898d6e0..7188c4df0 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@@ -9,7 +9,6 @@
 #include <numeric> // Required for std::accumulate
 #ifdef USE_FUSED_ATTN_CK
 #include <ck_fused_attn/ck_fused_attn.hpp>
-#include "ck_tile/host.hpp"
 #endif  // USE_FUSED_ATTN_CK
 #include "../util/cuda_runtime.h"
 #include "../util/system.h"
@@ -219,8 +218,8 @@ ck_fused_attn::MaskType set_ck_mask(NVTE_Mask_Type nvte_mask_type, int64_t nvte_
 __global__
 void generate_cu_seqlen_padded(
   uint32_t s_q, uint32_t s_kv, uint32_t b,
-  ck_tile::index_t* cu_seqlen_q_padded_ptr,
-  ck_tile::index_t* cu_seqlen_kv_padded_ptr
+  int32_t* cu_seqlen_q_padded_ptr,
+  int32_t* cu_seqlen_kv_padded_ptr
 ){
   for(int i = blockIdx.x * blockDim.x + threadIdx.x; i < b+1; i += blockDim.x * gridDim.x){
     cu_seqlen_q_padded_ptr[i] = s_q * i;
     cu_seqlen_kv_padded_ptr[i] = s_kv * i;
   }
 }
@@ -603,7 +602,7 @@ void fused_attn_ck_fwd_impl(
       (*workspace_size)+= q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes;
     }else{
       // cu_seqlen_padded buffers
-      (*workspace_size)+= 2*(b+1)*sizeof(ck_tile::index_t);
+      (*workspace_size)+= 2*(b+1)*sizeof(int32_t);
     }
   }else if(is_ragged){
     // We include a softmax_lse buffer to use the kernel in order to properly reshape the lse as needed.
     (*workspace_size)+= max_tokens_q*h*sizeof(float);
@@ -679,9 +678,9 @@ void fused_attn_ck_fwd_impl(
     }else{
       // cu_seqlen_padded ptrs for THD conversion
       devPtrSeqOffsetsQ = workspace_next;
-      workspace_next = static_cast<void*>(static_cast<char*>(workspace_next) + (b+1)*sizeof(ck_tile::index_t));
+      workspace_next = static_cast<void*>(static_cast<char*>(workspace_next) + (b+1)*sizeof(int32_t));
       devPtrSeqOffsetsKV = workspace_next;
-      workspace_next = static_cast<void*>(static_cast<char*>(workspace_next) + (b+1)*sizeof(ck_tile::index_t));
+      workspace_next = static_cast<void*>(static_cast<char*>(workspace_next) + (b+1)*sizeof(int32_t));
     }
     //determine the o buffer based on workspace next section
     devPtrOWithoutPadding = workspace_next;
@@ -734,10 +733,9 @@ void fused_attn_ck_fwd_impl(
     dim3 grid(ceil(1.0 * (b+1)/THREADS_PER_BLOCK));
     generate_cu_seqlen_padded<<<grid, block, 0, stream>>>(
       s_q, s_kv, b,
-      static_cast<ck_tile::index_t*>(devPtrSeqOffsetsQ),
-      static_cast<ck_tile::index_t*>(devPtrSeqOffsetsKV)
+      static_cast<int32_t*>(devPtrSeqOffsetsQ),
+      static_cast<int32_t*>(devPtrSeqOffsetsKV)
     );
-    is_ragged=true;
     if(nvte_log_ck_config){
       std::cout << "\nConverting BSHD to THD\n";
     }
   }
@@ -777,7 +775,7 @@ void fused_attn_ck_fwd_impl(
     // add padding for o and softmax_lse
     add_padding(dtype, b, h, s_q, d_v, max_tokens_q, false, o_stride[0], o_stride[1], o_stride[2], devPtrOWithoutPadding, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrO, stream);
     add_padding_softmax_lse(b, h, s_q, max_tokens_q, false, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream);
-  }else if(is_ragged){
+  }else if((is_BSHD && is_padding) || is_ragged){
     using ck_fused_attn::ck_attn_varlen_fwd;
     // TODO: remove the v3 api check after ck align softmax_lse with aiter asm
     bool is_v3_supported = ck_attn_varlen_fwd(
@@ -829,10 +827,10 @@ void fused_attn_ck_fwd_impl(
         stream));
     if(nvte_ck_uses_fwd_v3 && is_v3_supported){
       // aiter asm output softmax_lse with padding
-      add_padding_softmax_lse(b, h, s_q, max_tokens_q, true, devPtrSoftmaxLSEWithoutPadding, devPtrSeqOffsetsQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream);
+      add_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxLSEWithoutPadding, devPtrSeqOffsetsQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream);
     }else{
       // ck v2 output softmax_lse without padding
-      add_padding_softmax_lse(b, h, s_q, max_tokens_q, true, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream);
+      add_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream);
     }
   }else{
     using ck_fused_attn::ck_attn_fwd;

From 0e0064fb81f35f9b302021f02c375e36645a59e0 Mon Sep 17 00:00:00 2001
From: Ye Wang
Date: Mon, 20 Oct 2025 16:50:41 +0000
Subject: [PATCH 06/30] [ROCm] manually pick Meekail's PR to support native padding for bwd

---
 3rdparty/aiter                                |   2 +-
 .../include/ck_fused_attn/ck_fused_attn.hpp   |   3 +
 .../ck_fused_attn/src/ck_fused_attn_bwd.cpp   | 124 +++++--
 .../common/fused_attn_rocm/fused_attn_ck.cpp  | 345 +++++++++++++-----
 4 files changed, 345 insertions(+), 129 deletions(-)

diff --git a/3rdparty/aiter b/3rdparty/aiter
index 74e71eb8e..963b86542 160000
--- a/3rdparty/aiter
+++ 
b/3rdparty/aiter @@ -1 +1 @@ -Subproject commit 74e71eb8ee8a663d5e33c0cfd8b4dad7708ae84b +Subproject commit 963b86542014fbed69c17aaa4fb46b6d69ab3b9c diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 1eef86309..c3a8c866c 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -141,6 +141,8 @@ hipError_t ck_attn_varlen_bwd( const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, + const void* seqlen_q_ptr, const void* seqlen_kv_ptr, const void* o_ptr, uint64_t stride_h_o, uint64_t stride_s_o, const void* lse_thd_ptr, @@ -166,6 +168,7 @@ hipError_t ck_attn_varlen_bwd( bool uses_bwd_v3, bool is_v3_atomic_fp32, int how_v3_bf16_cvt, + bool is_v3_api_check, hipStream_t stream); }//namespace ck_fused_attn diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 2b717ace0..a7654e634 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -15,6 +15,24 @@ namespace ck_fused_attn{ +// TODO: unify with binary search in TE/common/fused_attn(rocm)/util +// no device std::upper_bound +// in an increasing array with given size len, search for the index that: +// array[index] <= target < array[index+1] +// guaranteed that target >=0 and target <= cu_seqlen[end-1] +__forceinline__ __device__ int binary_search(int32_t target, const int32_t *array, uint64_t len) { + int left = 1, right = len - 1; + while (left < right) { + int mid = (left + right) / 2; + if (array[mid] <= target) { + left = mid + 1; + } else { + right = mid; + } + } + return left - 1; +} + // define dk_dv_reduce function only for fp16 and bf16 types template __global__ void dk_dv_reduce( @@ -109,8 +127,9 @@ __global__ void dk_or_dv_reduce( // define dk_dv_reduce function in THD layout only for fp16 and bf16 types template __global__ void dk_dv_reduce_thd( - uint64_t h, uint64_t hg, uint64_t d, - const int32_t* total_seqlen_kv_ptr, + uint64_t b, uint64_t h, uint64_t hg, uint64_t d, + const int32_t* cu_seqlen_kv_ptr, + const int32_t* cu_seqlen_kv_padded_ptr, const DataType *dk_expanded, const DataType *dv_expanded, uint64_t stride_h_dkv_expanded, uint64_t stride_s_dkv_expanded, @@ -124,11 +143,17 @@ __global__ void dk_dv_reduce_thd( uint64_t hdim_idx = threadIdx.x; assert(hdim_idx= *total_seqlen_kv_ptr){ + + if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? 
cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ return; } - + if(cu_seqlen_kv_padded_ptr){ + uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1); + uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx]; + if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){ + return; + } + } // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -164,8 +189,9 @@ __global__ void dk_dv_reduce_thd( // When d_qk != d_v, we need to reduce dk and dv separately template __global__ void dk_or_dv_reduce_thd( - uint64_t h, uint64_t hg, uint64_t d, - const int32_t* total_seqlen_kv_ptr, + uint64_t b, uint64_t h, uint64_t hg, uint64_t d, + const int32_t* cu_seqlen_kv_ptr, + const int32_t* cu_seqlen_kv_padded_ptr, const DataType *dk_or_dv_expanded, uint64_t stride_h_dk_or_dv_expanded, uint64_t stride_s_dk_or_dv_expanded, DataType *dk_or_dv, @@ -178,10 +204,16 @@ __global__ void dk_or_dv_reduce_thd( assert(hdim_idx= *total_seqlen_kv_ptr){ + if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ return; } - + if(cu_seqlen_kv_padded_ptr){ + uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1); + uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx]; + if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){ + return; + } + } // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -312,6 +344,8 @@ void log_bwd_config(const char* func_name, const bool uses_bwd_v3, const bool is_v3_atomic_fp32, const int how_v3_bf16_cvt, + const void* cu_seqlen_q_padded_ptr, + const void* cu_seqlen_kv_padded_ptr, const fmha_bwd_args& fmha_args){ bool ck_fused_attn_log_config = false; @@ -337,6 +371,8 @@ void log_bwd_config(const char* func_name, std::cout<<"uses_bwd_v3: "< 0.f); @@ -919,7 +963,8 @@ hipError_t ck_attn_varlen_bwd( dq_acc_ptr, //dq_acc_buf cu_seqlen_q_ptr,//cu_seqlen_q cu_seqlen_kv_ptr,//cu_seqlen_kv - nullptr, /* seqlen_k_ptr */ + seqlen_q_ptr, /* seqlen_q_ptr */ + seqlen_kv_ptr, /* seqlen_k_ptr */ max_seqlen_q, //seqlen_q, unused in group mode max_seqlen_k, //seqlen_kv, unused in group mode batch, @@ -978,25 +1023,31 @@ hipError_t ck_attn_varlen_bwd( }(); // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, fmha_args); if (uses_bwd_v3) { set_aiter_asm_dir(); } - float average_runtime = aiter::mha_bwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_enum::no_bias, - has_dbias, - s_randval, - deterministic, - uses_bwd_v3, - is_v3_atomic_fp32, - how_v3_bf16_cvt); - if(average_runtime < 0){ + float average_runtime_or_v3_check_status = aiter::mha_bwd(fmha_args, + stream_config, + data_type_str, + is_group_mode, + mask_type, + bias_enum::no_bias, + has_dbias, + s_randval, + deterministic, + uses_bwd_v3, + is_v3_atomic_fp32, + how_v3_bf16_cvt, + uses_bwd_v3? cu_seqlen_q_padded_ptr: nullptr, + uses_bwd_v3? 
cu_seqlen_kv_padded_ptr: nullptr, + is_v3_api_check); + if(is_v3_api_check){ + return (hipError_t)(average_runtime_or_v3_check_status > 0); + } + if(average_runtime_or_v3_check_status < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); } @@ -1006,6 +1057,8 @@ hipError_t ck_attn_varlen_bwd( dim3 block(d_qk); if (ck_fused_attn_log_config){ std::cout<, grid, block, 0, stream, - h, hg, d_qk, - static_cast(cu_seqlen_kv_ptr)+b, + b, h, hg, d_qk, + static_cast(cu_seqlen_kv_ptr), + static_cast(cu_seqlen_kv_padded_ptr), static_cast(dk_expanded_ptr), static_cast(dv_expanded_ptr), stride_h_dk_expanded, stride_s_dk_expanded, @@ -1030,6 +1084,8 @@ hipError_t ck_attn_varlen_bwd( dim3 block_dk(d_qk); if (ck_fused_attn_log_config){ std::cout<, grid, block_dk, 0, stream, - h, hg, d_qk, - static_cast(cu_seqlen_kv_ptr)+b, + b, h, hg, d_qk, + static_cast(cu_seqlen_kv_ptr), + static_cast(cu_seqlen_kv_padded_ptr), static_cast(dk_expanded_ptr), stride_h_dk_expanded, stride_s_dk_expanded, static_cast(dk_ptr), @@ -1050,6 +1107,8 @@ hipError_t ck_attn_varlen_bwd( dim3 block_dv(d_v); if (ck_fused_attn_log_config){ std::cout<, grid, block_dv, 0, stream, - h, hg, d_v, - static_cast(cu_seqlen_kv_ptr)+b, + b, h, hg, d_v, + static_cast(cu_seqlen_kv_ptr), + static_cast(cu_seqlen_kv_padded_ptr), static_cast(dv_expanded_ptr), stride_h_dv_expanded, stride_s_dv_expanded, static_cast(dv_ptr), diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index b6a05e9a8..dd80a6c99 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -737,7 +737,6 @@ void fused_attn_ck_fwd_impl( add_padding_softmax_lse(b, h, s_q, max_tokens_q, false, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream); }else if(is_ragged){ using ck_fused_attn::ck_attn_varlen_fwd; - // TODO: remove the v3 api check after ck align softmax_lse with aiter asm bool is_v3_supported = ck_attn_varlen_fwd( nvte_to_ck_dtype(dtype), b, h, hg, s_q, s_kv, d_qk, d_v, @@ -785,13 +784,8 @@ void fused_attn_ck_fwd_impl( nvte_ck_how_v3_bf16_cvt, false, stream)); - if(is_v3_supported){ - // aiter asm output softmax_lse with padding - add_padding_softmax_lse(b, h, s_q, max_tokens_q, true, devPtrSoftmaxLSEWithoutPadding, devPtrSeqOffsetsQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream); - }else{ - // ck v2 output softmax_lse without padding - add_padding_softmax_lse(b, h, s_q, max_tokens_q, true, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream); - } + // aiter asm output softmax_lse with padding + add_padding_softmax_lse(b, h, s_q, max_tokens_q, true, devPtrSoftmaxLSEWithoutPadding, devPtrSeqOffsetsQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream); }else{ using ck_fused_attn::ck_attn_fwd; NVTE_CHECK_CUDA( @@ -820,6 +814,116 @@ void fused_attn_ck_fwd_impl( } } +// TODO: remove v3 api checking after ck v2 fully support native padding +bool is_ck_attn_bwd_varlen_v3_supported( + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, + size_t max_tokens_q, size_t max_tokens_kv, + float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, + bool deterministic, + DType dtype, + bool nvte_ck_uses_bwd_v3, + bool 
nvte_ck_is_v3_atomic_fp32, + int nvte_ck_how_v3_bf16_cvt, + cudaStream_t stream){ + + bool is_mqa_gqa = (h > hg); + bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; + + std::array q_stride; + std::array k_stride; + std::array v_stride; + std::array o_stride; + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + + std::array dk_expanded_stride; + std::array dv_expanded_stride; + if(is_mqa_gqa){ + generateMatrixStrides(b, h, s_q, s_kv, d_qk, dk_expanded_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, h, s_q, s_kv, d_v, dv_expanded_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + } + return ck_attn_varlen_bwd( + nvte_to_ck_dtype(dtype), + b, h, hg, s_q, s_kv, d_qk, d_v, + max_tokens_q, max_tokens_kv, + nullptr, + q_stride[1], (is_ragged? q_stride[2] : std::min(q_stride[0], q_stride[2])), + nullptr, + k_stride[1], (is_ragged? k_stride[2] : std::min(k_stride[0], k_stride[2])), + nullptr, + v_stride[1], (is_ragged? v_stride[2] : std::min(v_stride[0], v_stride[2])), + nullptr, nullptr, + nullptr, nullptr, + nullptr, nullptr, + nullptr, + o_stride[1], (is_ragged? o_stride[2] : std::min(o_stride[0], o_stride[2])), + nullptr, + nullptr, + o_stride[1], (is_ragged? o_stride[2] : std::min(o_stride[0], o_stride[2])), //dO and O share the same stride in TE + scaling_factor, dropout_probability, + nullptr, nullptr, + set_ck_mask(mask_type, window_size_left, window_size_right), + window_size_left, window_size_right, + nullptr, + q_stride[1], (is_ragged? q_stride[2] : std::min(q_stride[0], q_stride[2])), //dq and q share the same stride in TE + nullptr, + nullptr, + nullptr, + dk_expanded_stride[1], (is_ragged? dk_expanded_stride[2] : std::min(dk_expanded_stride[0], dk_expanded_stride[2])), //dK and K share the same stride + dv_expanded_stride[1], (is_ragged? dv_expanded_stride[2] : std::min(dv_expanded_stride[0], dv_expanded_stride[2])), //dV and V share the same stride + nullptr, + k_stride[1], (is_ragged? k_stride[2] : std::min(k_stride[0], k_stride[2])), //dK and K share the same stride + nullptr, + v_stride[1], (is_ragged? 
v_stride[2] : std::min(v_stride[0], v_stride[2])), //dV and V share the same stride
+    nullptr, // softmax_lsed
+    deterministic,
+    nvte_ck_uses_bwd_v3,
+    nvte_ck_is_v3_atomic_fp32,
+    nvte_ck_how_v3_bf16_cvt,
+    true, //v3_api_check, TODO: remove later
+    stream)==1;
+}
+
+__global__
+void cu_seqlen_to_seqlen_kernel(
+  uint64_t b,
+  const int32_t* cu_seqlen_q_ptr, const int32_t* cu_seqlen_kv_ptr,
+  int32_t* seqlen_q_ptr, int32_t* seqlen_kv_ptr){
+
+  for(int b_idx = blockIdx.x * blockDim.x + threadIdx.x; b_idx < b; b_idx += blockDim.x * gridDim.x){
+    seqlen_q_ptr[b_idx] = cu_seqlen_q_ptr[b_idx + 1] - cu_seqlen_q_ptr[b_idx];
+    seqlen_kv_ptr[b_idx] = cu_seqlen_kv_ptr[b_idx + 1] - cu_seqlen_kv_ptr[b_idx];
+  }
+}
+
+// kernel launcher to convert the cu_seqlen prefix sums into per-sequence lengths
+void cu_seqlen_to_seqlen(
+  uint64_t b,
+  const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
+  void* seqlen_q_ptr, void* seqlen_kv_ptr,
+  hipStream_t stream){
+
+  constexpr int THREADS_PER_BLOCK = 256;
+  dim3 block(THREADS_PER_BLOCK);
+  dim3 grid(ceil(1.0 * b / THREADS_PER_BLOCK));
+  cu_seqlen_to_seqlen_kernel<<<grid, block, 0, stream>>>(
+    b,
+    static_cast<const int32_t*>(cu_seqlen_q_ptr),
+    static_cast<const int32_t*>(cu_seqlen_kv_ptr),
+    static_cast<int32_t*>(seqlen_q_ptr),
+    static_cast<int32_t*>(seqlen_kv_ptr));
+}
 void fused_attn_ck_bwd_impl(
   uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h,
   bool pad_between_seqs, size_t max_tokens_q, size_t max_tokens_kv,
@@ -847,6 +951,11 @@ void fused_attn_ck_bwd_impl(
     if (env_p != nullptr && std::string(env_p) == "1") nvte_log_ck_config = true;
   }
+  // bwd v3 is optional by enabling the following envs
+  // default values follow the ck example settings
+  bool nvte_ck_uses_bwd_v3 = getenv("NVTE_CK_USES_BWD_V3", 0);
+  bool nvte_ck_is_v3_atomic_fp32 = getenv("NVTE_CK_IS_V3_ATOMIC_FP32", 1);
+  int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1);

   bool is_mqa_gqa = (h > hg);
   size_t kN0 = (d_qk <= 128)? 128:64;
   size_t nsplits = deterministic? 
ceil(1.0*s_kv/kN0):1; bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; + bool is_batch = (nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD || + nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD); + // extract the qkv and o storage bytes to allocate buffer for padding removing // b from cu_seqlen is not the actual storage batch for pad_between_seqs case size_t q_storage_bytes = max_tokens_q*h*d_qk*nvte_dtype_size(dtype); @@ -882,8 +994,22 @@ void fused_attn_ck_bwd_impl( if(pad_between_seqs){ // remove padding for the softmax_lse (*workspace_size)+= h*max_tokens_q*sizeof(float); - // allocate the q, k, v, o, do, dq, dk, dv, - (*workspace_size)+= 2*(q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes); + // TODO: remove v3 check after ck v2 fully support native padding + bool is_v3_supported = is_ck_attn_bwd_varlen_v3_supported(b, h, hg, s_q, s_kv, d_qk, d_v, max_tokens_q, max_tokens_kv, scaling_factor, dropout_probability, layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, dtype, nvte_ck_uses_bwd_v3, nvte_ck_is_v3_atomic_fp32, nvte_ck_how_v3_bf16_cvt, stream); + if(is_batch || is_ragged&&(!is_v3_supported)){ + // allocate the q, k, v, o, do, dq, dk, dv, + (*workspace_size)+= 2*(q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes); + if (nvte_log_ck_config) { + std::cout<(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); + NVTE_CHECK_CUDA(cudaMemsetAsync(lse_workspace, 0, h*max_tokens_q*sizeof(float), stream)); // The next section are for dq_acc_ptr void* dq_acc_ptr = workspace_next; @@ -1004,93 +1131,105 @@ void fused_attn_ck_bwd_impl( void* devPtrSoftmaxLSEWithoutPadding = nullptr; - void* devPtrQWithoutPadding = nullptr; - void* devPtrKWithoutPadding = nullptr; - void* devPtrVWithoutPadding = nullptr; - void* devPtrOWithoutPadding = nullptr; - void* devPtrdOWithoutPadding = nullptr; - void* devPtrdQWithoutPadding = nullptr; - void* devPtrdKWithoutPadding = nullptr; - void* devPtrdVWithoutPadding = nullptr; + void* devPtrQWithoutPadding = devPtrQ; + void* devPtrKWithoutPadding = devPtrK; + void* devPtrVWithoutPadding = devPtrV; + void* devPtrOWithoutPadding = devPtrO; + void* devPtrdOWithoutPadding = devPtrdO; + void* devPtrdQWithoutPadding = devPtrdQ; + void* devPtrdKWithoutPadding = devPtrdK; + void* devPtrdVWithoutPadding = devPtrdV; + + // TODO: remove after ck v2 support cu_seqlen_padded ptrs + void* seqlen_q_ptr = nullptr; + void* seqlen_kv_ptr = nullptr; + //TODO: remove v3 api check after v2 fully support native padding + bool is_v3_supported = false; if(pad_between_seqs){ devPtrSoftmaxLSEWithoutPadding = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); - //determine q, k, v buffer based on the workspace next ptr and layout group - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); - //Q ptr always comes at first - devPtrQWithoutPadding = workspace_next; - if(layout_group==NVTE_QKV_Layout_Group::NVTE_3HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_H3D){ - //keep the start address difference the same among q, k, and v - devPtrKWithoutPadding = static_cast(static_cast(devPtrQWithoutPadding) + (static_cast(devPtrK) - static_cast(devPtrQ))); - devPtrVWithoutPadding = static_cast(static_cast(devPtrQWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrQ))); - workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes + k_storage_bytes + v_storage_bytes); - }else 
if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){ - workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes); - //keep the start address difference the same between k and v - devPtrKWithoutPadding = workspace_next; - devPtrVWithoutPadding = static_cast(static_cast(devPtrKWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrK))); - workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes + v_storage_bytes); - }else{ - //qkv separated - workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes); - devPtrKWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes); - devPtrVWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + v_storage_bytes); - } - //determine the o, do buffer based on workspace next section - devPtrOWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + o_storage_bytes); - devPtrdOWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + o_storage_bytes); - //determine dq, dk, dv buffer based on the workspace next ptr and layout group - //dQ ptr always comes at first - devPtrdQWithoutPadding = workspace_next; - if(layout_group==NVTE_QKV_Layout_Group::NVTE_3HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_H3D){ - //keep the start address difference the same among q, k, and v - devPtrdKWithoutPadding = static_cast(static_cast(devPtrdQWithoutPadding) + (static_cast(devPtrK) - static_cast(devPtrQ))); - devPtrdVWithoutPadding = static_cast(static_cast(devPtrdQWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrQ))); - workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes + k_storage_bytes + v_storage_bytes); - - // zeroing out the entire dqkv since packed - NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQWithoutPadding, 0, q_storage_bytes + k_storage_bytes+ v_storage_bytes, stream)); - }else if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){ - workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes); - //keep the start address difference the same between k and v - devPtrdKWithoutPadding = workspace_next; - devPtrdVWithoutPadding = static_cast(static_cast(devPtrdKWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrK))); - workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes + v_storage_bytes); - - // zeroing out just the dq itself - NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQWithoutPadding, 0, q_storage_bytes, stream)); + //TODO: remove v3 api check after v2 fully support native padding + is_v3_supported = is_ck_attn_bwd_varlen_v3_supported(b, h, hg, s_q, s_kv, d_qk, d_v, max_tokens_q, max_tokens_kv, scaling_factor, dropout_probability, layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, dtype, nvte_ck_uses_bwd_v3, nvte_ck_is_v3_atomic_fp32, nvte_ck_how_v3_bf16_cvt, stream); + if(is_batch || is_ragged&&(!is_v3_supported)){ + //determine q, k, v buffer based on the workspace next ptr and layout group + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + //Q ptr always comes at first + devPtrQWithoutPadding = workspace_next; + if(layout_group==NVTE_QKV_Layout_Group::NVTE_3HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_H3D){ + //keep the start address difference the same among q, k, and v + devPtrKWithoutPadding = 
static_cast(static_cast(devPtrQWithoutPadding) + (static_cast(devPtrK) - static_cast(devPtrQ))); + devPtrVWithoutPadding = static_cast(static_cast(devPtrQWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrQ))); + workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes + k_storage_bytes + v_storage_bytes); + }else if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){ + workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes); + //keep the start address difference the same between k and v + devPtrKWithoutPadding = workspace_next; + devPtrVWithoutPadding = static_cast(static_cast(devPtrKWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrK))); + workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes + v_storage_bytes); + }else{ + //qkv separated + workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes); + devPtrKWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes); + devPtrVWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + v_storage_bytes); + } + //determine the o, do buffer based on workspace next section + devPtrOWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + o_storage_bytes); + devPtrdOWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + o_storage_bytes); + + //determine dq, dk, dv buffer based on the workspace next ptr and layout group + //dQ ptr always comes at first + devPtrdQWithoutPadding = workspace_next; + if(layout_group==NVTE_QKV_Layout_Group::NVTE_3HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_H3D){ + //keep the start address difference the same among q, k, and v + devPtrdKWithoutPadding = static_cast(static_cast(devPtrdQWithoutPadding) + (static_cast(devPtrK) - static_cast(devPtrQ))); + devPtrdVWithoutPadding = static_cast(static_cast(devPtrdQWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrQ))); + workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes + k_storage_bytes + v_storage_bytes); + + // zeroing out the entire dqkv since packed + NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQWithoutPadding, 0, q_storage_bytes + k_storage_bytes+ v_storage_bytes, stream)); + }else if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){ + workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes); + //keep the start address difference the same between k and v + devPtrdKWithoutPadding = workspace_next; + devPtrdVWithoutPadding = static_cast(static_cast(devPtrdKWithoutPadding) + (static_cast(devPtrV) - static_cast(devPtrK))); + workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes + v_storage_bytes); + + // zeroing out just the dq itself + NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQWithoutPadding, 0, q_storage_bytes, stream)); + }else{ + //qkv separated + workspace_next = static_cast(static_cast(workspace_next) + q_storage_bytes); + devPtrdKWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes); + devPtrdVWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + v_storage_bytes); + + // zeroing out just the dq itself + NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQWithoutPadding, 0, q_storage_bytes, stream)); + } }else{ - //qkv separated - workspace_next = 
static_cast(static_cast(workspace_next) + q_storage_bytes); - devPtrdKWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + k_storage_bytes); - devPtrdVWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + v_storage_bytes); - - // zeroing out just the dq itself - NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQWithoutPadding, 0, q_storage_bytes, stream)); + seqlen_q_ptr = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + b*sizeof(int32_t)); + seqlen_kv_ptr = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + b*sizeof(int32_t)); } }else if(is_ragged){ devPtrSoftmaxLSEWithoutPadding = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); } - // bwd v3 is optional by enabling the following envs - // default values follows the ck example setting - bool nvte_ck_uses_bwd_v3 = getenv("NVTE_CK_USES_BWD_V3", 0); - bool nvte_ck_is_v3_atomic_fp32 = getenv("NVTE_CK_IS_V3_ATOMIC_FP32", 1); - int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1); if (nvte_log_ck_config) { std::cout< Date: Tue, 21 Oct 2025 02:25:07 +0000 Subject: [PATCH 07/30] [ROCm] jax use runtime segment --- 3rdparty/aiter | 2 +- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 2 ++ .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 21 ++++++++++++++++--- .../common/fused_attn_rocm/fused_attn.cpp | 4 ++-- .../common/fused_attn_rocm/fused_attn_ck.cpp | 9 ++------ .../common/fused_attn_rocm/utils.cpp | 12 +++++------ .../jax/cpp_extensions/attention.py | 1 - .../jax/csrc/extensions/attention.cpp | 9 ++------ 8 files changed, 33 insertions(+), 27 deletions(-) diff --git a/3rdparty/aiter b/3rdparty/aiter index 74e71eb8e..963b86542 160000 --- a/3rdparty/aiter +++ b/3rdparty/aiter @@ -1 +1 @@ -Subproject commit 74e71eb8ee8a663d5e33c0cfd8b4dad7708ae84b +Subproject commit 963b86542014fbed69c17aaa4fb46b6d69ab3b9c diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 2b717ace0..fc0ed3977 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -574,6 +574,7 @@ hipError_t ck_attn_bwd( dq_acc_ptr, //dq_acc_buf nullptr,//cu_seqlen_q nullptr,//cu_seqlen_kv + nullptr, /* seqlen_q_ptr */ nullptr, /* seqlen_k_ptr */ shape_seqlen_q, shape_seqlen_k, @@ -919,6 +920,7 @@ hipError_t ck_attn_varlen_bwd( dq_acc_ptr, //dq_acc_buf cu_seqlen_q_ptr,//cu_seqlen_q cu_seqlen_kv_ptr,//cu_seqlen_kv + nullptr, /* seqlen_q_ptr */ nullptr, /* seqlen_k_ptr */ max_seqlen_q, //seqlen_q, unused in group mode max_seqlen_k, //seqlen_kv, unused in group mode diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index fc2f14f06..6c5846844 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -40,25 +40,27 @@ void log_fwd_config(const char* func_name, std::cout<::type>(mask_type)<::type>(bias_type)<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<= len) return; + if (tid >= max_batch_size) return; - if (cu_seqlen[tid] > 0) { + if (cu_seqlen[tid+1] - cu_seqlen[tid] > 0) { // atomicAdd only support 32 bits dtype atomicAdd(out, 1); } @@ -253,15 +253,15 @@ void PopulateRngStateAsync(void 
*rng_state_dst, NVTE_CHECK_CUDA(cudaGetLastError()); } -uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) { +uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t max_batch_size, cudaStream_t stream) { // workspace size requires 4 bytes uint32_t *dout = static_cast(workspace); uint32_t hout{}; cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream); constexpr int threads = 128; - const int blocks = (len - 1) / threads + 1; + const int blocks = (max_batch_size - 1) / threads + 1; // ceil get_runtime_num_segments_kernel<<>>(static_cast(cu_seqlen), - len, dout); + max_batch_size, dout); cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream); cudaStreamSynchronize(stream); return hout; diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index d1f701489..2af7e4262 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -585,7 +585,6 @@ def convert_to_2d(offsets, batch, max_seqlen): q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) - output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( q, k, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index af1fcb493..dead65394 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -235,18 +235,13 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \ size_t num_segments = input_batch; \ if (is_ragged) { \ - auto cudnn_runtime_version = cudnnGetVersion(); \ - if (cudnn_runtime_version >= 90300) { \ - num_segments = input_batch * max_segments_per_seq; \ - } else { \ size_t runtime_num_segments_q = nvte_get_runtime_num_segments( \ - q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \ + q_cu_seqlens, workspace, input_batch * max_segments_per_seq, stream); \ size_t runtime_num_segments_kv = nvte_get_runtime_num_segments( \ - kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \ + kv_cu_seqlens, workspace, input_batch * max_segments_per_seq, stream); \ NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); \ NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); \ num_segments = runtime_num_segments_q; \ - } \ } \ std::vector seq_shape{num_segments + 1}; \ auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, seq_shape, DType::kInt32); \ From 579b592857109123ec85334a4ac3c93ce9f55184 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 22 Oct 2025 03:54:01 +0000 Subject: [PATCH 08/30] [ROCm] get runtime max_seqlen as well --- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 7 ++++ .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 7 +++- .../ck_fused_attn/src/ck_fused_attn_utils.cpp | 35 +++++++++++++++++++ .../ck_fused_attn/src/ck_fused_attn_utils.hpp | 3 ++ 4 files changed, 51 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index fc0ed3977..560e00abd 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -979,6 +979,13 @@ hipError_t ck_attn_varlen_bwd( 
         std::pair{philox_seed_ptr, philox_offset_ptr}};
   }();

+  // modify the max_seqlen_q for better performance in 0-length cases
+  // lse_thd_ptr used as buffer
+  uint64_t runtime_max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream);
+  uint64_t runtime_max_seqlen_kv = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream);
+  fmha_args.max_seqlen_q = runtime_max_seqlen_q;
+  fmha_args.max_seqlen_k = runtime_max_seqlen_kv;
+
   // print ck traits and args when needed
   log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args);
   if (uses_bwd_v3)
diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp
index 6c5846844..212f0d878 100644
--- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp
+++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp
@@ -460,7 +460,12 @@ hipError_t ck_attn_varlen_fwd(
         false,
         std::pair{philox_seed_ptr, philox_offset_ptr}};
   }();
-
+  // modify the max_seqlen_q for better performance in 0-length cases
+  // lse_thd_ptr used as buffer
+  if(!is_v3_api_check){
+    uint64_t runtime_max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_thd_ptr, stream);
+    fmha_args.max_seqlen_q = runtime_max_seqlen_q;
+  }
   // print ck traits and args when needed
   log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, fmha_args);
   if (uses_fwd_v3)
diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
index ab9d28b2e..56618335f 100644
--- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
+++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
@@ -100,4 +100,39 @@ std::pair get_ck_bias_type_shape(BiasType attn_bias_type,
   return std::make_pair(bias_type, bias_shape);
 }

+__global__ void get_runtime_max_seqlen_kernel(
+  uint64_t b,
+  const int32_t* cu_seqlen_ptr,
+  const int32_t* cu_seqlen_padded_ptr,
+  uint64_t *out) {
+
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  if(tid >= b){
+    return;
+  }
+  if(cu_seqlen_padded_ptr){
+    atomicMax(out, cu_seqlen_padded_ptr[tid+1] - cu_seqlen_padded_ptr[tid]);
+  }else{
+    atomicMax(out, cu_seqlen_ptr[tid+1] - cu_seqlen_ptr[tid]);
+  }
+}
+
+uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream){
+  uint64_t* runtime_max_seqlen_ptr = static_cast<uint64_t*>(workspace);
+  uint64_t runtime_max_seqlen;
+  //reset the result buffer to 0
+  hipMemsetAsync(runtime_max_seqlen_ptr, 0, sizeof(uint64_t), stream);
+  constexpr int threads = 128;
+  // in case b == 0
+  const int blocks = (static_cast<int>(b) - 1) / threads + 1; // ceil
+  get_runtime_max_seqlen_kernel<<<blocks, threads, 0, stream>>>(
+    b,
+    static_cast<const int32_t*>(cu_seqlen_ptr),
+    static_cast<const int32_t*>(cu_seqlen_padded_ptr),
+    runtime_max_seqlen_ptr);
+  hipMemcpyAsync(&runtime_max_seqlen, runtime_max_seqlen_ptr, sizeof(uint64_t), hipMemcpyDeviceToHost, stream);
+  hipStreamSynchronize(stream);
+  return runtime_max_seqlen;
+}
+
 }//namespace ck_fused_attn
diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp 
b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index b0ce02ee0..d248d65ca 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -9,6 +9,7 @@ #include #include +#include //forward declaration for ck_tile enum enum class mask_enum; @@ -54,5 +55,7 @@ BiasShape get_bias_shape(uint64_t b, uint64_t h, uint64_t bias_b, uint64_t bias_ std::pair get_ck_bias_type_shape(BiasType attn_bias_type, uint64_t b, uint64_t h, uint64_t bias_b, uint64_t bias_h); void set_aiter_asm_dir(); +uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream); + }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_UTILS_H From 73247d9802e3e4a502e91e9bf7e735c243a7046b Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 22 Oct 2025 16:50:21 +0000 Subject: [PATCH 09/30] [ROCm] support v2 bwd native padding --- 3rdparty/aiter | 2 +- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 4 ++-- .../common/fused_attn_rocm/fused_attn_ck.cpp | 20 +++++++++---------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/3rdparty/aiter b/3rdparty/aiter index 963b86542..67ca00e87 160000 --- a/3rdparty/aiter +++ b/3rdparty/aiter @@ -1 +1 @@ -Subproject commit 963b86542014fbed69c17aaa4fb46b6d69ab3b9c +Subproject commit 67ca00e873b7684c2817bff1f35f1fab92a22c5c diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index a7654e634..f6eb5ff20 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -961,8 +961,8 @@ hipError_t ck_attn_varlen_bwd( is_mqa_gqa? dv_expanded_ptr:dv_ptr, nullptr, dq_acc_ptr, //dq_acc_buf - cu_seqlen_q_ptr,//cu_seqlen_q - cu_seqlen_kv_ptr,//cu_seqlen_kv + cu_seqlen_q_padded_ptr? cu_seqlen_q_padded_ptr: cu_seqlen_q_ptr,//seqstart_q_ptr + cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr,//seqstart_k_ptr seqlen_q_ptr, /* seqlen_q_ptr */ seqlen_kv_ptr, /* seqlen_k_ptr */ max_seqlen_q, //seqlen_q, unused in group mode diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index dd80a6c99..53bae45eb 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -956,7 +956,7 @@ void fused_attn_ck_bwd_impl( bool nvte_ck_uses_bwd_v3 = getenv("NVTE_CK_USES_BWD_V3", 0); bool nvte_ck_is_v3_atomic_fp32 = getenv("NVTE_CK_IS_V3_ATOMIC_FP32", 1); int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1); - + bool is_mqa_gqa = (h > hg); size_t kN0 = (d_qk <= 128)? 
128:64; @@ -996,7 +996,7 @@ void fused_attn_ck_bwd_impl( (*workspace_size)+= h*max_tokens_q*sizeof(float); // TODO: remove v3 check after ck v2 fully support native padding bool is_v3_supported = is_ck_attn_bwd_varlen_v3_supported(b, h, hg, s_q, s_kv, d_qk, d_v, max_tokens_q, max_tokens_kv, scaling_factor, dropout_probability, layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, dtype, nvte_ck_uses_bwd_v3, nvte_ck_is_v3_atomic_fp32, nvte_ck_how_v3_bf16_cvt, stream); - if(is_batch || is_ragged&&(!is_v3_supported)){ + if(is_batch || is_ragged&&is_v3_supported){ // allocate the q, k, v, o, do, dq, dk, dv, (*workspace_size)+= 2*(q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes); if (nvte_log_ck_config) { @@ -1007,7 +1007,7 @@ void fused_attn_ck_bwd_impl( // allocate the seqlen_padded's ptr (*workspace_size)+= 2*b*sizeof(int32_t); if (nvte_log_ck_config) { - std::cout< Date: Wed, 22 Oct 2025 12:52:05 -0500 Subject: [PATCH 10/30] Updated conversion to include bwd pass --- .../common/fused_attn_rocm/fused_attn_ck.cpp | 63 +++++++++++++++---- 1 file changed, 52 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 7188c4df0..78597566a 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -216,7 +216,7 @@ ck_fused_attn::MaskType set_ck_mask(NVTE_Mask_Type nvte_mask_type, int64_t nvte_ } __global__ -void generate_cu_seqlen_padded( +void generate_cu_seqlen_padded_kernel( uint32_t s_q, uint32_t s_kv, uint32_t b, int32_t* cu_seqlen_q_padded_ptr, int32_t* cu_seqlen_kv_padded_ptr @@ -227,6 +227,22 @@ void generate_cu_seqlen_padded( } } +void generate_cu_seqlen_padded( + uint32_t s_q, uint32_t s_kv, uint32_t b, + void* cu_seqlen_q_padded_ptr, + void* cu_seqlen_kv_padded_ptr, + hipStream_t stream +){ + constexpr int THREADS_PER_BLOCK = 256; + dim3 block(THREADS_PER_BLOCK); + dim3 grid(ceil(1.0 * (b+1)/THREADS_PER_BLOCK)); + generate_cu_seqlen_padded_kernel<<<grid, block, 0, stream>>>( + s_q, s_kv, b, + static_cast<int32_t*>(cu_seqlen_q_padded_ptr), + static_cast<int32_t*>(cu_seqlen_kv_padded_ptr) + ); +} + __global__ void generate_alibi_slope(uint64_t h, float* alibi_slope_ptr){ for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < h; id += blockDim.x * gridDim.x){ @@ -728,16 +744,9 @@ void fused_attn_ck_fwd_impl( } // If input is BSHD, we may directly convert to THD if(is_BSHD && is_padding){ - constexpr int THREADS_PER_BLOCK = 256; - dim3 block(THREADS_PER_BLOCK); - dim3 grid(ceil(1.0 * (b+1)/THREADS_PER_BLOCK)); - generate_cu_seqlen_padded<<<grid, block, 0, stream>>>( - s_q, s_kv, b, - static_cast<int32_t*>(devPtrSeqOffsetsQ), - static_cast<int32_t*>(devPtrSeqOffsetsKV) - ); + generate_cu_seqlen_padded(s_q, s_kv, b, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, stream); if(nvte_log_ck_config){ - std::cout << "\nConverting BSHD to THD\n"; + std::cout << "\nattn_fwd(ck): Converting BSHD to THD\n"; } } if(is_SBHD && is_padding){ @@ -893,7 +902,14 @@ void fused_attn_ck_bwd_impl( size_t kN0 = (d_qk <= 128)? 128:64; size_t nsplits = deterministic?
ceil(1.0*s_kv/kN0):1; - bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; + bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; + bool is_SBHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD; + bool is_BSHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD; + bool is_batch = is_BSHD || is_SBHD; + + bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); // extract the qkv and o storage bytes to allocate buffer for padding removing // b from cu_seqlen is not the actual storage batch for pad_between_seqs case size_t q_storage_bytes = max_tokens_q*h*d_qk*nvte_dtype_size(dtype); @@ -919,7 +935,13 @@ void fused_attn_ck_bwd_impl( //ck requires a buffer dbias_expanded of size BHSS if bias is not BHSS (*workspace_size) += b*h*s_q*s_kv*nvte_dtype_size(dtype); } + // TODO(micky774): Avoid workaround when native padding/unpadding support is + // available if(pad_between_seqs){ + if(is_BSHD && is_padding){ + // cu_seqlen_padded buffers + (*workspace_size)+= 2*(b+1)*sizeof(int32_t); + } // remove padding for the softmax_lse (*workspace_size)+= h*max_tokens_q*sizeof(float); // allocate the q, k, v, o, do, dq, dk, dv, @@ -1053,9 +1075,12 @@ void fused_attn_ck_bwd_impl( void* devPtrdKWithoutPadding = nullptr; void* devPtrdVWithoutPadding = nullptr; + // TODO(micky774): Avoid workaround when native padding/unpadding support is + // available if(pad_between_seqs){ devPtrSoftmaxLSEWithoutPadding = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); + //determine q, k, v buffer based on the workspace next ptr and layout group NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); //Q ptr always comes at first @@ -1116,6 +1141,13 @@ void fused_attn_ck_bwd_impl( // zeroing out just the dq itself NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQWithoutPadding, 0, q_storage_bytes, stream)); } + if(is_BSHD && is_padding){ + // cu_seqlen_padded ptrs for THD conversion + devPtrSeqOffsetsQ = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); + devPtrSeqOffsetsKV = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); + } }else if(is_ragged){ devPtrSoftmaxLSEWithoutPadding = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); @@ -1160,6 +1192,15 @@ void fused_attn_ck_bwd_impl( std::cout<<"nvte_ck_is_v3_atomic_fp32: "< Date: Thu, 23 Oct 2025 15:21:59 -0500 Subject: [PATCH 11/30] Added BWD BSHD-->THD conversion and minor logic refactor --- .../common/fused_attn_rocm/fused_attn_ck.cpp | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 0857c1626..565c06545 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -868,7 +868,7 @@ bool is_ck_attn_bwd_varlen_v3_supported( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, size_t max_tokens_q, size_t max_tokens_kv, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, + NVTE_QKV_Layout layout, bool is_ragged, NVTE_Bias_Type bias_type, NVTE_Mask_Type 
mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, @@ -879,7 +879,6 @@ bool is_ck_attn_bwd_varlen_v3_supported( cudaStream_t stream){ bool is_mqa_gqa = (h > hg); - bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; std::array q_stride; std::array k_stride; @@ -1011,10 +1010,15 @@ void fused_attn_ck_bwd_impl( size_t kN0 = (d_qk <= 128)? 128:64; size_t nsplits = deterministic? ceil(1.0*s_kv/kN0):1; - bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; - bool is_batch = (nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD || - nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD); - + bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; + bool is_SBHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD; + bool is_BSHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD; + bool is_batch = is_BSHD || is_SBHD; + bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + bool bshd_to_thd = is_BSHD && is_padding; + // extract the qkv and o storage bytes to allocate buffer for padding removing // b from cu_seqlen is not the actual storage batch for pad_between_seqs case size_t q_storage_bytes = max_tokens_q*h*d_qk*nvte_dtype_size(dtype); @@ -1022,6 +1026,12 @@ void fused_attn_ck_bwd_impl( size_t v_storage_bytes = max_tokens_kv*hg*d_v*nvte_dtype_size(dtype); size_t o_storage_bytes = max_tokens_q*h*d_v*nvte_dtype_size(dtype); + bool is_v3_supported = false; + bool needs_padding_conversion = ( + is_SBHD || + (bshd_to_thd && is_v3_supported) || + (is_ragged&&is_v3_supported) + ); // Exit to request upper level API to allocate memory if needed if(workspace==nullptr){ size_t workspace_size_lse = max_tokens_q*h*sizeof(float); @@ -1043,15 +1053,15 @@ void fused_attn_ck_bwd_impl( // TODO(micky774): Avoid workaround when native padding/unpadding support is // available if(pad_between_seqs){ - if(is_BSHD && is_padding){ + if(bshd_to_thd){ // cu_seqlen_padded buffers (*workspace_size)+= 2*(b+1)*sizeof(int32_t); } + // TODO: remove v3 check after ck v2 fully support native padding + is_v3_supported = is_ck_attn_bwd_varlen_v3_supported(b, h, hg, s_q, s_kv, d_qk, d_v, max_tokens_q, max_tokens_kv, scaling_factor, dropout_probability, layout, is_ragged || bshd_to_thd, bias_type, mask_type, window_size_left, window_size_right, deterministic, dtype, nvte_ck_uses_bwd_v3, nvte_ck_is_v3_atomic_fp32, nvte_ck_how_v3_bf16_cvt, stream); // remove padding for the softmax_lse (*workspace_size)+= h*max_tokens_q*sizeof(float); - // TODO: remove v3 check after ck v2 fully support native padding - bool is_v3_supported = is_ck_attn_bwd_varlen_v3_supported(b, h, hg, s_q, s_kv, d_qk, d_v, max_tokens_q, max_tokens_kv, scaling_factor, dropout_probability, layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, dtype, nvte_ck_uses_bwd_v3, nvte_ck_is_v3_atomic_fp32, nvte_ck_how_v3_bf16_cvt, stream); - if(is_batch || is_ragged&&is_v3_supported){ + if(needs_padding_conversion){ // allocate the q, k, v, o, do, dq, dk, dv, (*workspace_size)+= 2*(q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes); if (nvte_log_ck_config) { @@ -1199,15 +1209,11 @@ void fused_attn_ck_bwd_impl( void* seqlen_q_ptr = nullptr; void* seqlen_kv_ptr = nullptr; - //TODO: remove v3 api check after v2 fully support native padding - bool is_v3_supported = false; 
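// NOTE (illustrative sketch, not part of the patch): the needs_padding_conversion flag
// introduced above collapses to a single predicate over the layout/mask flags. A hypothetical
// restatement, assuming the same locals as fused_attn_ck_bwd_impl:
//
//   auto needs_conversion = [&](bool v3_supported) {
//     // SBHD always takes the explicit pack/unpack workaround; THD data (native, or produced
//     // by the BSHD->THD shortcut) only does when the v3 kernel path will be used.
//     return is_SBHD || ((bshd_to_thd || is_ragged) && v3_supported);
//   };
//
// which is equivalent to is_SBHD || (bshd_to_thd && is_v3_supported) || (is_ragged && is_v3_supported).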
if(pad_between_seqs){ devPtrSoftmaxLSEWithoutPadding = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); - //TODO: remove v3 api check after v2 fully support native padding - is_v3_supported = is_ck_attn_bwd_varlen_v3_supported(b, h, hg, s_q, s_kv, d_qk, d_v, max_tokens_q, max_tokens_kv, scaling_factor, dropout_probability, layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, dtype, nvte_ck_uses_bwd_v3, nvte_ck_is_v3_atomic_fp32, nvte_ck_how_v3_bf16_cvt, stream); - if(is_batch || is_ragged&&is_v3_supported){ + if(needs_padding_conversion){ //determine q, k, v buffer based on the workspace next ptr and layout group NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); //Q ptr always comes at first @@ -1274,7 +1280,7 @@ void fused_attn_ck_bwd_impl( seqlen_kv_ptr = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + b*sizeof(int32_t)); } - if(is_BSHD && is_padding){ + if(bshd_to_thd){ // cu_seqlen_padded ptrs for THD conversion devPtrSeqOffsetsQ = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); @@ -1322,7 +1328,7 @@ void fused_attn_ck_bwd_impl( std::cout<<"nvte_ck_how_v3_bf16_cvt: "< Date: Thu, 23 Oct 2025 16:26:26 -0500 Subject: [PATCH 12/30] Corrected softmax lse bug --- transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 565c06545..6dfe2f800 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -834,7 +834,7 @@ void fused_attn_ck_fwd_impl( false, stream)); // aiter asm output softmax_lse with padding - add_padding_softmax_lse(b, h, s_q, max_tokens_q, true, devPtrSoftmaxLSEWithoutPadding, devPtrSeqOffsetsQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream); + add_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxLSEWithoutPadding, devPtrSeqOffsetsQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream); }else{ using ck_fused_attn::ck_attn_fwd; NVTE_CHECK_CUDA( @@ -1030,7 +1030,7 @@ void fused_attn_ck_bwd_impl( bool needs_padding_conversion = ( is_SBHD || (bshd_to_thd && is_v3_supported) || - (is_ragged&&is_v3_supported) + (is_ragged && is_v3_supported) ); // Exit to request upper level API to allocate memory if needed if(workspace==nullptr){ From 5c2418874cb29aeebca1e0ade032da66da21017a Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 23 Oct 2025 17:23:20 -0500 Subject: [PATCH 13/30] Updated logic flow and re-calculation --- .../common/fused_attn_rocm/fused_attn_ck.cpp | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 6dfe2f800..169d3a86c 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -596,6 +596,7 @@ void fused_attn_ck_fwd_impl( bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + bool bshd_to_thd = is_BSHD && is_padding; // extract the qkv and o storage bytes to allocate buffer for padding removing // b from cu_seqlen
is not the actual storage batch for pad_between_seqs case @@ -743,7 +744,7 @@ void fused_attn_ck_fwd_impl( std::cout<<"nvte_ck_uses_fwd_v3: "<(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); + is_v3_supported = is_ck_attn_bwd_varlen_v3_supported(b, h, hg, s_q, s_kv, d_qk, d_v, max_tokens_q, max_tokens_kv, scaling_factor, dropout_probability, layout, is_ragged || bshd_to_thd, bias_type, mask_type, window_size_left, window_size_right, deterministic, dtype, nvte_ck_uses_bwd_v3, nvte_ck_is_v3_atomic_fp32, nvte_ck_how_v3_bf16_cvt, stream); + needs_padding_conversion = ( + (is_SBHD && is_padding) || + (bshd_to_thd && is_v3_supported) || + (is_ragged && is_v3_supported) + ); if(needs_padding_conversion){ //determine q, k, v buffer based on the workspace next ptr and layout group From b59d466c1d57c7c432fc2a7c31f5cc28620f2943 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Mon, 20 Oct 2025 16:50:41 +0000 Subject: [PATCH 14/30] [ROCm] manually pick Meekail's PR to support native padding for bwd [ROCm] support v2 bwd native padding --- 3rdparty/aiter | 2 +- .../include/ck_fused_attn/ck_fused_attn.hpp | 2 + .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 138 ++++++++++++----- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 27 ++-- .../common/fused_attn_rocm/fused_attn_ck.cpp | 140 ++++++++++-------- 5 files changed, 193 insertions(+), 116 deletions(-) diff --git a/3rdparty/aiter b/3rdparty/aiter index 963b86542..cad028e1d 160000 --- a/3rdparty/aiter +++ b/3rdparty/aiter @@ -1 +1 @@ -Subproject commit 963b86542014fbed69c17aaa4fb46b6d69ab3b9c +Subproject commit cad028e1d10efdbc179fa1e42b6c65f204e168cd diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 1eef86309..dddb81adb 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -141,6 +141,7 @@ hipError_t ck_attn_varlen_bwd( const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, const void* o_ptr, uint64_t stride_h_o, uint64_t stride_s_o, const void* lse_thd_ptr, @@ -166,6 +167,7 @@ hipError_t ck_attn_varlen_bwd( bool uses_bwd_v3, bool is_v3_atomic_fp32, int how_v3_bf16_cvt, + bool is_v3_api_check, hipStream_t stream); }//namespace ck_fused_attn diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 560e00abd..704d71c0d 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -15,6 +15,24 @@ namespace ck_fused_attn{ +// TODO: unify with binary search in TE/common/fused_attn(rocm)/util +// no device std::upper_bound +// in an increasing array with given size len, search for the index that: +// array[index] <= target < array[index+1] +// guaranteed that target >=0 and target <= cu_seqlen[end-1] +__forceinline__ __device__ int binary_search(int32_t target, const int32_t *array, uint64_t len) { + int left = 1, right = len - 1; + while (left < right) { + int mid = (left + right) / 2; + if (array[mid] <= target) { + left = mid + 1; + } else { + right = mid; + } + } + return left - 1; +} + // define dk_dv_reduce function only for fp16 and bf16 types template __global__ void 
dk_dv_reduce( @@ -109,8 +127,9 @@ __global__ void dk_or_dv_reduce( // define dk_dv_reduce function in THD layout only for fp16 and bf16 types template __global__ void dk_dv_reduce_thd( - uint64_t h, uint64_t hg, uint64_t d, - const int32_t* total_seqlen_kv_ptr, + uint64_t b, uint64_t h, uint64_t hg, uint64_t d, + const int32_t* cu_seqlen_kv_ptr, + const int32_t* cu_seqlen_kv_padded_ptr, const DataType *dk_expanded, const DataType *dv_expanded, uint64_t stride_h_dkv_expanded, uint64_t stride_s_dkv_expanded, @@ -124,11 +143,17 @@ __global__ void dk_dv_reduce_thd( uint64_t hdim_idx = threadIdx.x; assert(hdim_idx= *total_seqlen_kv_ptr){ + + if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ return; } - + if(cu_seqlen_kv_padded_ptr){ + uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1); + uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx]; + if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){ + return; + } + } // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -164,8 +189,9 @@ __global__ void dk_dv_reduce_thd( // When d_qk != d_v, we need to reduce dk and dv separately template __global__ void dk_or_dv_reduce_thd( - uint64_t h, uint64_t hg, uint64_t d, - const int32_t* total_seqlen_kv_ptr, + uint64_t b, uint64_t h, uint64_t hg, uint64_t d, + const int32_t* cu_seqlen_kv_ptr, + const int32_t* cu_seqlen_kv_padded_ptr, const DataType *dk_or_dv_expanded, uint64_t stride_h_dk_or_dv_expanded, uint64_t stride_s_dk_or_dv_expanded, DataType *dk_or_dv, @@ -178,10 +204,16 @@ __global__ void dk_or_dv_reduce_thd( assert(hdim_idx= *total_seqlen_kv_ptr){ + if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ return; } - + if(cu_seqlen_kv_padded_ptr){ + uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1); + uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx]; + if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){ + return; + } + } // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -312,6 +344,8 @@ void log_bwd_config(const char* func_name, const bool uses_bwd_v3, const bool is_v3_atomic_fp32, const int how_v3_bf16_cvt, + const void* cu_seqlen_q_padded_ptr, + const void* cu_seqlen_kv_padded_ptr, const fmha_bwd_args& fmha_args){ bool ck_fused_attn_log_config = false; @@ -337,6 +371,8 @@ void log_bwd_config(const char* func_name, std::cout<<"uses_bwd_v3: "< 0.f); @@ -916,12 +965,14 @@ hipError_t ck_attn_varlen_bwd( dq_ptr, is_mqa_gqa? dk_expanded_ptr:dk_ptr, is_mqa_gqa? dv_expanded_ptr:dv_ptr, - nullptr, + nullptr, //dbias_ptr dq_acc_ptr, //dq_acc_buf - cu_seqlen_q_ptr,//cu_seqlen_q - cu_seqlen_kv_ptr,//cu_seqlen_kv + cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr + cu_seqlen_kv_padded_ptr==nullptr? 
cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr nullptr, /* seqlen_q_ptr */ nullptr, /* seqlen_k_ptr */ + cu_seqlen_q_ptr, //cu_seqlen_q_ptr + cu_seqlen_kv_ptr, //cu_seqlen_k_ptr max_seqlen_q, //seqlen_q, unused in group mode max_seqlen_k, //seqlen_kv, unused in group mode batch, @@ -987,25 +1038,31 @@ hipError_t ck_attn_varlen_bwd( fmha_args.max_seqlen_k = runtime_max_seqlen_kv; // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, fmha_args); if (uses_bwd_v3) { set_aiter_asm_dir(); } - float average_runtime = aiter::mha_bwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_enum::no_bias, - has_dbias, - s_randval, - deterministic, - uses_bwd_v3, - is_v3_atomic_fp32, - how_v3_bf16_cvt); - if(average_runtime < 0){ + float average_runtime_or_v3_check_status = aiter::mha_bwd(fmha_args, + stream_config, + data_type_str, + is_group_mode, + mask_type, + bias_enum::no_bias, + has_dbias, + s_randval, + deterministic, + uses_bwd_v3, + is_v3_atomic_fp32, + how_v3_bf16_cvt, + uses_bwd_v3? cu_seqlen_q_padded_ptr: nullptr, + uses_bwd_v3? cu_seqlen_kv_padded_ptr: nullptr, + is_v3_api_check); + if(is_v3_api_check){ + return (hipError_t)(average_runtime_or_v3_check_status > 0); + } + if(average_runtime_or_v3_check_status < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); } @@ -1015,6 +1072,8 @@ hipError_t ck_attn_varlen_bwd( dim3 block(d_qk); if (ck_fused_attn_log_config){ std::cout<, grid, block, 0, stream, - h, hg, d_qk, - static_cast(cu_seqlen_kv_ptr)+b, + b, h, hg, d_qk, + static_cast(cu_seqlen_kv_ptr), + static_cast(cu_seqlen_kv_padded_ptr), static_cast(dk_expanded_ptr), static_cast(dv_expanded_ptr), stride_h_dk_expanded, stride_s_dk_expanded, @@ -1039,6 +1099,8 @@ hipError_t ck_attn_varlen_bwd( dim3 block_dk(d_qk); if (ck_fused_attn_log_config){ std::cout<, grid, block_dk, 0, stream, - h, hg, d_qk, - static_cast(cu_seqlen_kv_ptr)+b, + b, h, hg, d_qk, + static_cast(cu_seqlen_kv_ptr), + static_cast(cu_seqlen_kv_padded_ptr), static_cast(dk_expanded_ptr), stride_h_dk_expanded, stride_s_dk_expanded, static_cast(dk_ptr), @@ -1059,6 +1122,8 @@ hipError_t ck_attn_varlen_bwd( dim3 block_dv(d_v); if (ck_fused_attn_log_config){ std::cout<, grid, block_dv, 0, stream, - h, hg, d_v, - static_cast(cu_seqlen_kv_ptr)+b, + b, h, hg, d_v, + static_cast(cu_seqlen_kv_ptr), + static_cast(cu_seqlen_kv_padded_ptr), static_cast(dv_expanded_ptr), stride_h_dv_expanded, stride_s_dv_expanded, static_cast(dv_ptr), diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 212f0d878..a1d26b0f8 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -69,15 +69,12 @@ void log_fwd_config(const char* func_name, std::cout<<"lse_ptr: "<data.dptr!=input_cu_seqlens_padded->data.dptr - && !input_cu_seqlens_padded->data.shape.empty() - ); - // Next we guard against an initial workspace-allocation which occurs in 
the - // JAX TE extension. We check for both pointers being null while retaining - // shape data, indicating the use of dummy data in the allocation pass. - pad_between_seqs = pad_between_seqs || ( - is_ragged - && input_cu_seqlens->data.dptr==nullptr && !input_cu_seqlens->data.shape.empty() - && input_cu_seqlens_padded->data.dptr==nullptr && !input_cu_seqlens_padded->data.shape.empty() - ); - // Finally we check whether we have an array with padding and non-empty input_cu_seqlens - pad_between_seqs = pad_between_seqs || ( - !is_ragged - && is_padding - && !input_cu_seqlens->data.shape.empty() - ); - return pad_between_seqs; -} // check the fused attn config to see whether it's ck backend supported // single filtering followed by joint filtering bool is_ck_backend_supported( @@ -737,7 +709,6 @@ void fused_attn_ck_fwd_impl( add_padding_softmax_lse(b, h, s_q, max_tokens_q, false, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream); }else if(is_ragged){ using ck_fused_attn::ck_attn_varlen_fwd; - // TODO: remove the v3 api check after ck align softmax_lse with aiter asm bool is_v3_supported = ck_attn_varlen_fwd( nvte_to_ck_dtype(dtype), b, h, hg, s_q, s_kv, d_qk, d_v, @@ -786,7 +757,7 @@ void fused_attn_ck_fwd_impl( false, stream)); // aiter asm output softmax_lse with padding - add_padding_softmax_lse(b, h, s_q, max_tokens_q, true, devPtrSoftmaxLSEWithoutPadding, devPtrSeqOffsetsQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream); + add_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxLSEWithoutPadding, devPtrSeqOffsetsQ, devPtrSeqOffsetsQ, devPtrSoftmaxAux, stream); }else{ using ck_fused_attn::ck_attn_fwd; NVTE_CHECK_CUDA( @@ -817,7 +788,7 @@ void fused_attn_ck_fwd_impl( void fused_attn_ck_bwd_impl( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, - bool pad_between_seqs, size_t max_tokens_q, size_t max_tokens_kv, + size_t max_tokens_q, size_t max_tokens_kv, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -842,13 +813,19 @@ void fused_attn_ck_bwd_impl( if (env_p != nullptr && std::string(env_p) == "1") nvte_log_ck_config = true; } - + bool is_mqa_gqa = (h > hg); size_t kN0 = (d_qk <= 128)? 128:64; size_t nsplits = deterministic? 
ceil(1.0*s_kv/kN0):1; bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; + bool is_batch = (nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD || + nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD); + bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + // extract the qkv and o storage bytes to allocate buffer for padding removing // b from cu_seqlen is not the actual storage batch for pad_between_seqs case size_t q_storage_bytes = max_tokens_q*h*d_qk*nvte_dtype_size(dtype); @@ -874,13 +851,13 @@ void fused_attn_ck_bwd_impl( //ck requires a buffer dbias_expanded of size BHSS if bias is not BHSS (*workspace_size) += b*h*s_q*s_kv*nvte_dtype_size(dtype); } - if(pad_between_seqs){ + if(is_batch && is_padding){ // remove padding for the softmax_lse (*workspace_size)+= h*max_tokens_q*sizeof(float); // allocate the q, k, v, o, do, dq, dk, dv, (*workspace_size)+= 2*(q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes); }else if(is_ragged){ - // remove padding for the softmax_lse + // convert softmax lse from te format to ck/aiter format (*workspace_size)+= h*max_tokens_q*sizeof(float); } if (nvte_log_ck_config) { @@ -918,8 +895,9 @@ void fused_attn_ck_bwd_impl( }else{ // HD_2HD, HD_H2D, HD_HD_HD can just memset dq itself NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes, stream)); + //TODO: check whether we can remove this zero-out in padding/unpadding workaround // for pad between seqs case, we need to reset all dq, dk, dv - if(pad_between_seqs){ + if(is_batch && is_padding){ if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){ //kvpacked NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes + v_storage_bytes, stream)); @@ -1007,10 +985,11 @@ void fused_attn_ck_bwd_impl( void* devPtrdQWithoutPadding = nullptr; void* devPtrdKWithoutPadding = nullptr; void* devPtrdVWithoutPadding = nullptr; - - if(pad_between_seqs){ + + if(is_batch && is_padding){ devPtrSoftmaxLSEWithoutPadding = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); + //determine q, k, v buffer based on the workspace next ptr and layout group NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); //Q ptr always comes at first @@ -1086,6 +1065,7 @@ void fused_attn_ck_bwd_impl( std::cout<<"layout: "<data).shape.begin(), (input_QKV->data).shape.end(), static_cast(1), std::multiplies())/h/d/3; @@ -1474,7 +1492,7 @@ void fused_attn_ck_bwd_qkvpacked( fused_attn_ck_bwd_impl( b, h, h, max_seqlen, max_seqlen, d, d, bias_b, bias_h, - pad_between_seqs, max_tokens, max_tokens, + max_tokens, max_tokens, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, @@ -1616,7 +1634,6 @@ void fused_attn_ck_fwd_kvpacked( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - bool pad_between_seqs = get_pad_between_seqs(input_cu_seqlens_q, input_cu_seqlens_q_padded, is_ragged, is_padding); fused_attn_ck_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, bias_b, bias_h, @@ -1717,7 +1734,6 @@ void fused_attn_ck_bwd_kvpacked( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || 
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - bool pad_between_seqs = get_pad_between_seqs(input_cu_seqlens_q, input_cu_seqlens_q_padded, is_ragged, is_padding); // extract the max_tokens for padding/unpadding and softmax_lse buffer // b from cu_seqlen and max_seqlen are not the actual storage batch and seqlen for pad_between_seqs case @@ -1726,7 +1742,7 @@ void fused_attn_ck_bwd_kvpacked( fused_attn_ck_bwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, bias_b, bias_h, - pad_between_seqs, max_tokens_q, max_tokens_kv, + max_tokens_q, max_tokens_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, @@ -1858,7 +1874,6 @@ void fused_attn_ck_fwd( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - bool pad_between_seqs = get_pad_between_seqs(input_cu_seqlens_q, input_cu_seqlens_q_padded, is_ragged, is_padding); fused_attn_ck_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h, @@ -1948,7 +1963,6 @@ void fused_attn_ck_bwd( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - bool pad_between_seqs = get_pad_between_seqs(input_cu_seqlens_q, input_cu_seqlens_q_padded, is_ragged, is_padding); // extract the max_tokens for padding/unpadding and softmax_lse buffer // b from cu_seqlen and max_seqlen are not the actual storage batch and seqlen for pad_between_seqs case @@ -1957,7 +1971,7 @@ void fused_attn_ck_bwd( fused_attn_ck_bwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h, - pad_between_seqs, max_tokens_q, max_tokens_kv, + max_tokens_q, max_tokens_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, From f27a99fc81833bb1953460aab15adf5e0444bdd6 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 28 Oct 2025 18:08:04 +0000 Subject: [PATCH 15/30] Added env var guard --- .../common/ck_fused_attn/src/ck_fused_attn_bwd.cpp | 13 +++++++++---- .../common/ck_fused_attn/src/ck_fused_attn_fwd.cpp | 10 +++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 704d71c0d..03f775d34 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -1032,10 +1032,15 @@ hipError_t ck_attn_varlen_bwd( // modify the max_seqlen_q for better performance in 0-length cases // lse_thd_ptr used as buffer - uint64_t runtime_max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); - uint64_t runtime_max_seqlen_kv = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); - fmha_args.max_seqlen_q = runtime_max_seqlen_q; - fmha_args.max_seqlen_k = runtime_max_seqlen_kv; + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { + if(std::string(env_p) == "1" && !is_v3_api_check){ + if(ck_fused_attn_log_config){ + std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + } + fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); + fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, 
cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); + } + } // print ck traits and args when needed log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, fmha_args); diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index a1d26b0f8..e95032204 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -457,9 +457,13 @@ hipError_t ck_attn_varlen_fwd( }(); // modify the max_seqlen_q for better performance in 0-length cases // lse_thd_ptr used as buffer - if(!is_v3_api_check){ - uint64_t runtime_max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_thd_ptr, stream); - fmha_args.max_seqlen_q = runtime_max_seqlen_q; + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ + if(std::string(env_p) == "1" && !is_v3_api_check){ + if(ck_fused_attn_log_config){ + std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + } + fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_thd_ptr, stream); + } } // print ck traits and args when needed log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, fmha_args); From 33c59126485987cddfb325159769f8d30faadffe Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 29 Oct 2025 15:04:50 +0000 Subject: [PATCH 16/30] Updated ptr variables and streamlined dispatch --- .../common/fused_attn_rocm/fused_attn_ck.cpp | 102 +++++++++--------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index eb873daa8..81886f4a8 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -629,6 +629,8 @@ void fused_attn_ck_fwd_impl( void* devPtrKWithoutPadding = nullptr; void* devPtrVWithoutPadding = nullptr; void* devPtrOWithoutPadding = nullptr; + void* ptrSeqOffsetsQ = devPtrSeqOffsetsQ; + void* ptrSeqOffsetsKV = devPtrSeqOffsetsKV; // next h*max_tokens_q*sizeof(float) in workspace are for lse buffer @@ -660,10 +662,14 @@ void fused_attn_ck_fwd_impl( } }else if(bshd_to_thd){ // cu_seqlen_padded ptrs for THD conversion - devPtrSeqOffsetsQ = workspace_next; + ptrSeqOffsetsQ = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); - devPtrSeqOffsetsKV = workspace_next; + ptrSeqOffsetsKV = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); + generate_cu_seqlen_padded(s_q, s_kv, b, ptrSeqOffsetsQ, ptrSeqOffsetsKV, stream); + if(nvte_log_ck_config){ + std::cout << "\nattn_fwd(ck): Converting BSHD to THD\n"; + } } //determine the o buffer based on workspace next section devPtrOWithoutPadding = workspace_next; @@ -676,7 +682,10 @@ void fused_attn_ck_fwd_impl( std::cout<<"layout: "<(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); @@ -1102,10 +1104,14 @@ void fused_attn_ck_bwd_impl( } }else 
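// NOTE (illustrative, not from the patch): for the BSHD->THD shortcut below, the padded
// offsets are presumably uniform multiples of the max sequence length, i.e.
// cu_seqlen_padded[i] = i * s, since every sequence occupies a full s-length slot in BSHD
// storage. For example, b = 3 and s_q = 128 would give cu_seqlen_q_padded = [0, 128, 256, 384],
// which is what generate_cu_seqlen_padded(s_q, s_kv, b, ...) fills into the two (b+1)-entry
// int32 buffers carved out of the workspace here.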
if(bshd_to_thd){ // cu_seqlen_padded ptrs for THD conversion - devPtrSeqOffsetsQ = workspace_next; + ptrSeqOffsetsQ = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); - devPtrSeqOffsetsKV = workspace_next; + ptrSeqOffsetsKV = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); + generate_cu_seqlen_padded(s_q, s_kv, b, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, stream); + if(nvte_log_ck_config){ + std::cout << "\nattn_bwd(ck): Converting BSHD to THD\n"; + } } @@ -1116,6 +1122,8 @@ void fused_attn_ck_bwd_impl( std::cout<<"max_tokens_kv: "< Date: Wed, 29 Oct 2025 15:21:07 +0000 Subject: [PATCH 17/30] Added env guard --- transformer_engine/jax/cpp_extensions/attention.py | 1 + transformer_engine/jax/csrc/extensions/attention.cpp | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 2af7e4262..d1f701489 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -585,6 +585,7 @@ def convert_to_2d(offsets, batch, max_seqlen): q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) + output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( q, k, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index dead65394..43f915871 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -234,7 +234,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \ auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \ size_t num_segments = input_batch; \ - if (is_ragged) { \ + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_NUM_SEGMENTS")){ \ + if(std::string(env_p) == "1" && is_ragged){ \ size_t runtime_num_segments_q = nvte_get_runtime_num_segments( \ q_cu_seqlens, workspace, input_batch * max_segments_per_seq, stream); \ size_t runtime_num_segments_kv = nvte_get_runtime_num_segments( \ @@ -242,6 +243,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); \ NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); \ num_segments = runtime_num_segments_q; \ + } \ } \ std::vector seq_shape{num_segments + 1}; \ auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, seq_shape, DType::kInt32); \ From bc8f4a7285fe3251581f3539b6ce7002b48f0b58 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 29 Oct 2025 19:08:25 +0000 Subject: [PATCH 18/30] Corrected bshd_to_thd conversion arguments --- .../common/fused_attn_rocm/fused_attn_ck.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 81886f4a8..26c7fbf8a 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -1108,7 +1108,7 @@ void fused_attn_ck_bwd_impl( workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); ptrSeqOffsetsKV = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + 
(b+1)*sizeof(int32_t)); - generate_cu_seqlen_padded(s_q, s_kv, b, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, stream); + generate_cu_seqlen_padded(s_q, s_kv, b, ptrSeqOffsetsQ, ptrSeqOffsetsKV, stream); if(nvte_log_ck_config){ std::cout << "\nattn_bwd(ck): Converting BSHD to THD\n"; } @@ -1153,14 +1153,14 @@ void fused_attn_ck_bwd_impl( } if(is_SBHD && is_padding){ // remove padding for q, k, v, o, do - remove_padding(dtype, b, h, s_q, d_qk, max_tokens_q, is_ragged, q_stride[0], q_stride[1], q_stride[2], devPtrQ, devPtrCuSeqlensQ, ptrSeqOffsetsQ, devPtrQWithoutPadding, stream); - remove_padding(dtype, b, hg, s_kv, d_qk, max_tokens_kv, is_ragged, k_stride[0], k_stride[1], k_stride[2], devPtrK, devPtrCuSeqlensKV, ptrSeqOffsetsKV, devPtrKWithoutPadding, stream); - remove_padding(dtype, b, hg, s_kv, d_v, max_tokens_kv, is_ragged, v_stride[0], v_stride[1], v_stride[2], devPtrV, devPtrCuSeqlensKV, ptrSeqOffsetsKV, devPtrVWithoutPadding, stream); + remove_padding(dtype, b, h, s_q, d_qk, max_tokens_q, is_ragged, q_stride[0], q_stride[1], q_stride[2], devPtrQ, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrQWithoutPadding, stream); + remove_padding(dtype, b, hg, s_kv, d_qk, max_tokens_kv, is_ragged, k_stride[0], k_stride[1], k_stride[2], devPtrK, devPtrCuSeqlensKV, devPtrSeqOffsetsKV, devPtrKWithoutPadding, stream); + remove_padding(dtype, b, hg, s_kv, d_v, max_tokens_kv, is_ragged, v_stride[0], v_stride[1], v_stride[2], devPtrV, devPtrCuSeqlensKV, devPtrSeqOffsetsKV, devPtrVWithoutPadding, stream); // o and do should be of same shape as q - remove_padding(dtype, b, h, s_q, d_v, max_tokens_q, is_ragged, o_stride[0], o_stride[1], o_stride[2], devPtrO, devPtrCuSeqlensQ, ptrSeqOffsetsQ, devPtrOWithoutPadding, stream); - remove_padding(dtype, b, h, s_q, d_v, max_tokens_q, is_ragged, o_stride[0], o_stride[1], o_stride[2], devPtrdO, devPtrCuSeqlensQ, ptrSeqOffsetsQ, devPtrdOWithoutPadding, stream); + remove_padding(dtype, b, h, s_q, d_v, max_tokens_q, is_ragged, o_stride[0], o_stride[1], o_stride[2], devPtrO, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrOWithoutPadding, stream); + remove_padding(dtype, b, h, s_q, d_v, max_tokens_q, is_ragged, o_stride[0], o_stride[1], o_stride[2], devPtrdO, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrdOWithoutPadding, stream); // also remove the padding for softmax lse - remove_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxAux, devPtrCuSeqlensQ, ptrSeqOffsetsQ, devPtrSoftmaxLSEWithoutPadding, stream); + remove_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxAux, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, devPtrSoftmaxLSEWithoutPadding, stream); using ck_fused_attn::ck_attn_varlen_bwd; NVTE_CHECK_CUDA( ck_attn_varlen_bwd( @@ -1204,9 +1204,9 @@ void fused_attn_ck_bwd_impl( stream)); // add padding for dq, dk, dv // dq, dk, dv of same shape as q, k, v - add_padding(dtype, b, h, s_q, d_qk, max_tokens_q, is_ragged, q_stride[0], q_stride[1], q_stride[2], devPtrdQWithoutPadding, devPtrCuSeqlensQ, ptrSeqOffsetsQ, devPtrdQ, stream); - add_padding(dtype, b, hg, s_kv, d_qk, max_tokens_kv, is_ragged, k_stride[0], k_stride[1], k_stride[2], devPtrdKWithoutPadding, devPtrCuSeqlensKV, ptrSeqOffsetsKV, devPtrdK, stream); - add_padding(dtype, b, hg, s_kv, d_v, max_tokens_kv, is_ragged, v_stride[0], v_stride[1], v_stride[2], devPtrdVWithoutPadding, devPtrCuSeqlensKV, ptrSeqOffsetsKV, devPtrdV, stream); + add_padding(dtype, b, h, s_q, d_qk, max_tokens_q, is_ragged, q_stride[0], q_stride[1], q_stride[2], devPtrdQWithoutPadding, devPtrCuSeqlensQ, devPtrSeqOffsetsQ, 
devPtrdQ, stream); + add_padding(dtype, b, hg, s_kv, d_qk, max_tokens_kv, is_ragged, k_stride[0], k_stride[1], k_stride[2], devPtrdKWithoutPadding, devPtrCuSeqlensKV, devPtrSeqOffsetsKV, devPtrdK, stream); + add_padding(dtype, b, hg, s_kv, d_v, max_tokens_kv, is_ragged, v_stride[0], v_stride[1], v_stride[2], devPtrdVWithoutPadding, devPtrCuSeqlensKV, devPtrSeqOffsetsKV, devPtrdV, stream); }else if(bshd_to_thd || is_ragged){ remove_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxAux, ptrSeqOffsetsQ, ptrSeqOffsetsQ, devPtrSoftmaxLSEWithoutPadding, stream); using ck_fused_attn::ck_attn_varlen_bwd; From b7f2cf8afef0db2786351d7cd55e58d643fdc7d3 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 30 Oct 2025 18:27:54 +0000 Subject: [PATCH 19/30] Corrected logical flow --- transformer_engine/jax/csrc/extensions/attention.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 43f915871..c13c704de 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -234,8 +234,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \ auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \ size_t num_segments = input_batch; \ - if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_NUM_SEGMENTS")){ \ - if(std::string(env_p) == "1" && is_ragged){ \ + if(is_ragged){ \ + auto cudnn_runtime_version = cudnnGetVersion(); \ + num_segments = input_batch * max_segments_per_seq; \ + bool use_runtime_num_segments_check = false; \ + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_NUM_SEGMENTS")){ \ + use_runtime_num_segments_check = std::string(env_p) == "1"; \ + } \ + if(cudnn_runtime_version < 90300 || use_runtime_num_segments_check){ \ size_t runtime_num_segments_q = nvte_get_runtime_num_segments( \ q_cu_seqlens, workspace, input_batch * max_segments_per_seq, stream); \ size_t runtime_num_segments_kv = nvte_get_runtime_num_segments( \ From 3e48a029dcd6f291b10e91c8828779f3399315a4 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 5 Nov 2025 13:03:19 -0600 Subject: [PATCH 20/30] Guarded memset and corrected allocation --- .../common/fused_attn_rocm/fused_attn_ck.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 26c7fbf8a..0c907f3d5 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -637,6 +637,10 @@ void fused_attn_ck_fwd_impl( devPtrSoftmaxLSEWithoutPadding = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); if(is_SBHD && is_padding){ + //determine the o buffer based on workspace next section + devPtrOWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + o_storage_bytes); + //determine q, k ,v buffer based on the workspace next ptr and layout group NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); //Q ptr always comes at first @@ -671,11 +675,10 @@ void fused_attn_ck_fwd_impl( std::cout << "\nattn_fwd(ck): Converting BSHD to THD\n"; } } - //determine the o buffer based on workspace next section - 
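// NOTE (illustrative sketch, not part of the patch): the recurring pattern in these
// workspace sections of taking workspace_next and then advancing it by a byte count is a
// simple bump allocator over the caller-provided workspace. A hypothetical helper capturing it:
//
//   auto carve = [&](size_t bytes) {
//     void* p = workspace_next;
//     workspace_next = static_cast<void*>(static_cast<char*>(workspace_next) + bytes);
//     return p;
//   };
//
// with which the allocations below would read, e.g., devPtrOWithoutPadding = carve(o_storage_bytes);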
   devPtrOWithoutPadding = workspace_next;
-  workspace_next = static_cast(static_cast(workspace_next) + o_storage_bytes);
-  // reset the final results since padded places need to be 0
-  NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrO, 0, o_storage_bytes, stream));
+  if(is_batch && is_padding){
+    // reset the final results since padded places need to be 0
+    NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrO, 0, o_storage_bytes, stream));
+  }
   if (nvte_log_ck_config) {
     std::cout<(static_cast(workspace_next) + h*max_tokens_q*sizeof(float));
-  NVTE_CHECK_CUDA(cudaMemsetAsync(lse_workspace, 0, h*max_tokens_q*sizeof(float), stream));

   // The next section are for dq_acc_ptr
   void* dq_acc_ptr = workspace_next;

From b1094c698be04bad1a12d4d492147137351b1309 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Wed, 5 Nov 2025 14:36:19 -0600
Subject: [PATCH 21/30] Remove V3 API check and guard memsets

---
 .../include/ck_fused_attn/ck_fused_attn.hpp  |  2 -
 .../ck_fused_attn/src/ck_fused_attn_bwd.cpp  | 12 +--
 .../ck_fused_attn/src/ck_fused_attn_fwd.cpp  | 12 +--
 .../common/fused_attn_rocm/fused_attn_ck.cpp | 79 ++-----------------
 4 files changed, 15 insertions(+), 90 deletions(-)

diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp
index dddb81adb..54ee94786 100644
--- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp
+++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp
@@ -85,7 +85,6 @@ hipError_t ck_attn_varlen_fwd(
   void* lse_thd_ptr,
   bool uses_fwd_v3,
   int how_v3_bf16_cvt,
-  bool is_v3_api_check,
   hipStream_t stream);

 hipError_t ck_attn_bwd(
@@ -167,7 +166,6 @@ hipError_t ck_attn_varlen_bwd(
   bool uses_bwd_v3,
   bool is_v3_atomic_fp32,
   int how_v3_bf16_cvt,
-  bool is_v3_api_check,
   hipStream_t stream);

 }//namespace ck_fused_attn
diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
index 03f775d34..db04aa5b7 100644
--- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
+++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
@@ -858,7 +858,6 @@ hipError_t ck_attn_varlen_bwd(
   bool uses_bwd_v3,
   bool is_v3_atomic_fp32,
   int how_v3_bf16_cvt,
-  bool is_v3_api_check,
   hipStream_t stream){

   bool has_dropout = (dropout_probability > 0.f);
@@ -1033,7 +1032,7 @@
   // modify the max_seqlen_q for better performance in 0-length cases
   // lse_thd_ptr used as buffer
   if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) {
-    if(std::string(env_p) == "1" && !is_v3_api_check){
+    if(std::string(env_p) == "1"){
       if(ck_fused_attn_log_config){
         std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.";
       }
@@ -1061,12 +1060,9 @@
     uses_bwd_v3,
     is_v3_atomic_fp32,
     how_v3_bf16_cvt,
-    uses_bwd_v3? cu_seqlen_q_padded_ptr: nullptr,
-    uses_bwd_v3? cu_seqlen_kv_padded_ptr: nullptr,
-    is_v3_api_check);
-  if(is_v3_api_check){
-    return (hipError_t)(average_runtime_or_v3_check_status > 0);
-  }
+    nullptr,
+    nullptr,
+    false);
   if(average_runtime_or_v3_check_status < 0){
     //TODO: better error out system
     throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass.");
diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp
index e95032204..1751faeb7 100644
--- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp
+++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp
@@ -324,7 +324,6 @@ hipError_t ck_attn_varlen_fwd(
   void* lse_thd_ptr,
   bool uses_fwd_v3,
   int how_v3_bf16_cvt,
-  bool is_v3_api_check,
   hipStream_t stream){

   bool has_dropout = (is_training && dropout_probability > 0.f);
@@ -458,7 +457,7 @@
   // modify the max_seqlen_q for better performance in 0-length cases
   // lse_thd_ptr used as buffer
   if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){
-    if(std::string(env_p) == "1" && !is_v3_api_check){
+    if(std::string(env_p) == "1"){
       if(ck_fused_attn_log_config){
         std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.";
       }
@@ -482,12 +481,9 @@
     has_lse,
     uses_fwd_v3,
     how_v3_bf16_cvt,
-    uses_fwd_v3? cu_seqlen_q_padded_ptr: nullptr,
-    uses_fwd_v3? cu_seqlen_kv_padded_ptr: nullptr,
-    is_v3_api_check);
-  if(is_v3_api_check){
-    return (hipError_t)(average_runtime_or_v3_check_status > 0);
-  }
+    nullptr,
+    nullptr,
+    false);
   if(average_runtime_or_v3_check_status < 0){
     //TODO: better error out system
     throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass.");
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
index 0c907f3d5..96b089a17 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@@ -744,36 +744,12 @@ void fused_attn_ck_fwd_impl(
         devPtrSoftmaxLSEWithoutPadding,
         nvte_ck_uses_fwd_v3,
         nvte_ck_how_v3_bf16_cvt,
-        false,
         stream));
     // add padding for o and softmax_lse
     add_padding(dtype, b, h, s_q, d_v, max_tokens_q, false, o_stride[0], o_stride[1], o_stride[2], devPtrOWithoutPadding, devPtrCuSeqlensQ, ptrSeqOffsetsQ, devPtrO, stream);
     add_padding_softmax_lse(b, h, s_q, max_tokens_q, false, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, ptrSeqOffsetsQ, devPtrSoftmaxAux, stream);
   }else if(bshd_to_thd || is_ragged){
     using ck_fused_attn::ck_attn_varlen_fwd;
-    bool is_v3_supported = ck_attn_varlen_fwd(
-      nvte_to_ck_dtype(dtype),
-      b, h, hg, s_q, s_kv, d_qk, d_v,
-      max_tokens_q,
-      devPtrQ,
-      q_stride[1], q_stride[2],
-      devPtrK,
-      k_stride[1], k_stride[2],
-      devPtrV,
-      v_stride[1], v_stride[2],
-      devPtrCuSeqlensQ, devPtrCuSeqlensKV,
-      ptrSeqOffsetsQ, ptrSeqOffsetsKV,
-      is_training, scaling_factor, dropout_probability,
-      devPtrDropoutSeed, devPtrDropoutOffset,
-      set_ck_mask(mask_type, window_size_left, window_size_right),
-      window_size_left, window_size_right,
-      devPtrO,
-      o_stride[1], o_stride[2],
-      devPtrSoftmaxLSEWithoutPadding,
-      nvte_ck_uses_fwd_v3,
-      nvte_ck_how_v3_bf16_cvt,
-      true, // check whether v3 is supported
-      stream)==1;
     NVTE_CHECK_CUDA(
       ck_attn_varlen_fwd(
         nvte_to_ck_dtype(dtype),
@@ -794,9 +770,8 @@
         devPtrO,
         o_stride[1], o_stride[2],
         devPtrSoftmaxLSEWithoutPadding,
-        nvte_ck_uses_fwd_v3 && is_v3_supported,
+        nvte_ck_uses_fwd_v3,
         nvte_ck_how_v3_bf16_cvt,
-        false,
         stream));
     // aiter asm output softmax_lse with padding
     add_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxLSEWithoutPadding, ptrSeqOffsetsQ, ptrSeqOffsetsQ, devPtrSoftmaxAux, stream);
@@ -882,7 +857,6 @@ void fused_attn_ck_bwd_impl(
   size_t v_storage_bytes = max_tokens_kv*hg*d_v*nvte_dtype_size(dtype);
   size_t o_storage_bytes = max_tokens_q*h*d_v*nvte_dtype_size(dtype);

-  bool is_v3_supported = false;
   // Exit to request upper level API to allocate memory if needed
   if(workspace==nullptr){
     size_t workspace_size_lse = max_tokens_q*h*sizeof(float);
@@ -944,13 +918,15 @@
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
   if((layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) || (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D)){
     // just memset all dq, dk, dv
-    NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes + k_storage_bytes+ v_storage_bytes, stream));
+    if(is_batch && is_padding){
+      NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes + k_storage_bytes+ v_storage_bytes, stream));
+    }
   }else{
-    // HD_2HD, HD_H2D, HD_HD_HD can just memset dq itself
-    NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes, stream));
     //TODO: check whether we can remove this zero-out in padding/unpadding workaround
     // for pad between seqs case, we need to reset all dq, dk, dv
     if(is_batch && is_padding){
+      // HD_2HD, HD_H2D, HD_HD_HD can just memset dq itself
+      NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes, stream));
       if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){
         //kvpacked
         NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes + v_storage_bytes, stream));
@@ -1202,7 +1178,6 @@
         nvte_ck_uses_bwd_v3,
         nvte_ck_is_v3_atomic_fp32,
         nvte_ck_how_v3_bf16_cvt,
-        false, //v3_api_check, TODO: remove later
         stream));
     // add padding for dq, dk, dv
     // dq, dk, dv of same shape as q, k, v
@@ -1212,45 +1187,6 @@
   }else if(bshd_to_thd || is_ragged){
     remove_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxAux, ptrSeqOffsetsQ, ptrSeqOffsetsQ, devPtrSoftmaxLSEWithoutPadding, stream);
     using ck_fused_attn::ck_attn_varlen_bwd;
-    bool is_v3_supported = ck_attn_varlen_bwd(
-      nvte_to_ck_dtype(dtype),
-      b, h, hg, s_q, s_kv, d_qk, d_v,
-      max_tokens_q, max_tokens_kv,
-      devPtrQ,
-      q_stride[1], q_stride[2],
-      devPtrK,
-      k_stride[1], k_stride[2],
-      devPtrV,
-      v_stride[1], v_stride[2],
-      devPtrCuSeqlensQ, devPtrCuSeqlensKV,
-      ptrSeqOffsetsQ, ptrSeqOffsetsKV,
-      devPtrO,
-      o_stride[1], o_stride[2],
-      devPtrSoftmaxLSEWithoutPadding,
-      devPtrdO,
-      o_stride[1], o_stride[2], //dO and O share the same stride
-      scaling_factor, dropout_probability,
-      devPtrDropoutSeed, devPtrDropoutOffset,
-      set_ck_mask(mask_type, window_size_left, window_size_right),
-      window_size_left, window_size_right,
-      devPtrdQ,
-      q_stride[1], q_stride[2], //dQ and Q share the same stride
-      dq_acc_ptr,
-      dk_expanded_ptr,
-      dv_expanded_ptr,
-      dk_expanded_stride[1], dk_expanded_stride[2], //dK and K share the same stride
-      dv_expanded_stride[1], dv_expanded_stride[2], //dV and V share the same stride
-      devPtrdK,
-      k_stride[1], k_stride[2], //dK and K share the same stride
-      devPtrdV,
-      v_stride[1], v_stride[2], //dV and V share the same stride
-      lse_workspace, // softmax_lsed
-      deterministic,
-      nvte_ck_uses_bwd_v3,
-      nvte_ck_is_v3_atomic_fp32,
-      nvte_ck_how_v3_bf16_cvt,
-      true, //v3_api_check, TODO: remove later
-      stream)==1;
     NVTE_CHECK_CUDA(
       ck_attn_varlen_bwd(
         nvte_to_ck_dtype(dtype),
@@ -1286,10 +1222,9 @@
         v_stride[1], v_stride[2], //dV and V share the same stride
         lse_workspace, // softmax_lsed
         deterministic,
-        nvte_ck_uses_bwd_v3 && is_v3_supported,
+        nvte_ck_uses_bwd_v3,
        nvte_ck_is_v3_atomic_fp32,
        nvte_ck_how_v3_bf16_cvt,
-        false, //v3_api_check, TODO: remove later
        stream));
   }else{
     using ck_fused_attn::ck_attn_bwd;
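Patch 21 collapses what used to be a double call into the varlen entry points: the kernel was first invoked with the API-check flag raised, and the float return value was reinterpreted as a support status rather than a runtime, before a second call did the real launch. A minimal, self-contained C++ sketch of that overloaded-return convention; probe_or_run is a hypothetical stand-in for the aiter entry point, and only the ">0 means supported" / "negative means unsupported config" conventions come from the hunks above:

    #include <iostream>
    #include <stdexcept>

    // Hypothetical stand-in: when is_api_check is true the float encodes a
    // support status (> 0 means supported); otherwise it is an average
    // runtime, with a negative value signalling an unsupported config.
    float probe_or_run(bool use_v3, bool is_api_check) {
      if (is_api_check) {
        return use_v3 ? 1.f : 0.f;  // support status, not a runtime
      }
      return 0.42f;  // pretend average runtime in ms
    }

    int main() {
      // Old two-call pattern removed by this patch: probe first ...
      bool v3_supported = probe_or_run(/*use_v3=*/true, /*is_api_check=*/true) > 0;
      // ... then launch, masking the v3 request with the probe result.
      float avg = probe_or_run(v3_supported, /*is_api_check=*/false);
      if (avg < 0) throw std::runtime_error("config not supported");
      std::cout << "runtime: " << avg << "\n";
    }

After this patch the probe call disappears and the launch simply passes nullptr padded pointers and a false check flag straight through.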
From c3a0fce5801df51a423a5904de27f2bfe41698ba Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Thu, 6 Nov 2025 15:01:18 -0600
Subject: [PATCH 22/30] PR comments

---
 .../ck_fused_attn/src/ck_fused_attn_bwd.cpp  | 22 +++--------
 .../ck_fused_attn/src/ck_fused_attn_fwd.cpp  | 22 +++--------
 .../common/fused_attn_rocm/fused_attn_ck.cpp | 39 +++++++++----------
 .../jax/csrc/extensions/attention.cpp        |  2 +-
 4 files changed, 32 insertions(+), 53 deletions(-)

diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
index db04aa5b7..14baad367 100644
--- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
+++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
@@ -344,8 +344,6 @@ void log_bwd_config(const char* func_name,
   const bool uses_bwd_v3,
   const bool is_v3_atomic_fp32,
   const int how_v3_bf16_cvt,
-  const void* cu_seqlen_q_padded_ptr,
-  const void* cu_seqlen_kv_padded_ptr,
   const fmha_bwd_args& fmha_args){

   bool ck_fused_attn_log_config = false;
@@ -371,8 +369,6 @@
   std::cout<<"uses_bwd_v3: "<(static_cast(workspace_next) + (b+1)*sizeof(int32_t));
-  ptrSeqOffsetsKV = workspace_next;
+  devPtrCuSeqlenPaddedKV = workspace_next;
   workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t));
-  generate_cu_seqlen_padded(s_q, s_kv, b, ptrSeqOffsetsQ, ptrSeqOffsetsKV, stream);
+  generate_cu_seqlen_padded(s_q, s_kv, b, devPtrCuSeqlenPaddedQ, devPtrCuSeqlenPaddedKV, stream);
   if(nvte_log_ck_config){
     std::cout << "\nattn_fwd(ck): Converting BSHD to THD\n";
   }
 }
 if(is_SBHD && is_padding){
   // remove padding for q, k, v
-  remove_padding(dtype, b, h, s_q, d_qk, max_tokens_q, false, q_stride[0], q_stride[1], q_stride[2], devPtrQ, devPtrCuSeqlensQ, ptrSeqOffsetsQ, devPtrQWithoutPadding, stream);
-  remove_padding(dtype, b, hg, s_kv, d_qk, max_tokens_kv, false, k_stride[0], k_stride[1], k_stride[2], devPtrK, devPtrCuSeqlensKV, ptrSeqOffsetsKV, devPtrKWithoutPadding, stream);
-  remove_padding(dtype, b, hg, s_kv, d_v, max_tokens_kv, false, v_stride[0], v_stride[1], v_stride[2], devPtrV, devPtrCuSeqlensKV, ptrSeqOffsetsKV, devPtrVWithoutPadding, stream);
+  remove_padding(dtype, b, h, s_q, d_qk, max_tokens_q, false, q_stride[0], q_stride[1], q_stride[2], devPtrQ, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, devPtrQWithoutPadding, stream);
+  remove_padding(dtype, b, hg, s_kv, d_qk, max_tokens_kv, false, k_stride[0], k_stride[1], k_stride[2], devPtrK, devPtrCuSeqlensKV, devPtrCuSeqlenPaddedKV, devPtrKWithoutPadding, stream);
+  remove_padding(dtype, b, hg, s_kv, d_v, max_tokens_kv, false, v_stride[0], v_stride[1], v_stride[2], devPtrV, devPtrCuSeqlensKV, devPtrCuSeqlenPaddedKV, devPtrVWithoutPadding, stream);
   // call varlen api using without_padding ptrs
   // for BSHD/SBHD, after padding removal, THD require stride_s update
   using ck_fused_attn::ck_attn_varlen_fwd;
@@ -746,8 +746,8 @@ void fused_attn_ck_fwd_impl(
         nvte_ck_how_v3_bf16_cvt,
         stream));
     // add padding for o and softmax_lse
-    add_padding(dtype, b, h, s_q, d_v, max_tokens_q, false, o_stride[0], o_stride[1], o_stride[2], devPtrOWithoutPadding, devPtrCuSeqlensQ, ptrSeqOffsetsQ, devPtrO, stream);
-    add_padding_softmax_lse(b, h, s_q, max_tokens_q, false, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, ptrSeqOffsetsQ, devPtrSoftmaxAux, stream);
+    add_padding(dtype, b, h, s_q, d_v, max_tokens_q, false, o_stride[0], o_stride[1], o_stride[2], devPtrOWithoutPadding, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, devPtrO, stream);
+    add_padding_softmax_lse(b, h, s_q, max_tokens_q, false, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, devPtrSoftmaxAux, stream);
   }else if(bshd_to_thd || is_ragged){
     using ck_fused_attn::ck_attn_varlen_fwd;
     NVTE_CHECK_CUDA(
@@ -762,7 +762,7 @@ void fused_attn_ck_fwd_impl(
         devPtrV,
         v_stride[1], v_stride[2],
         devPtrCuSeqlensQ, devPtrCuSeqlensKV,
-        ptrSeqOffsetsQ, ptrSeqOffsetsKV,
+        devPtrCuSeqlenPaddedQ, devPtrCuSeqlenPaddedKV,
         is_training, scaling_factor, dropout_probability,
         devPtrDropoutSeed, devPtrDropoutOffset,
         set_ck_mask(mask_type, window_size_left, window_size_right),
@@ -774,7 +774,7 @@ void fused_attn_ck_fwd_impl(
         nvte_ck_how_v3_bf16_cvt,
         stream));
     // aiter asm output softmax_lse with padding
-    add_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxLSEWithoutPadding, ptrSeqOffsetsQ, ptrSeqOffsetsQ, devPtrSoftmaxAux, stream);
+    add_padding_softmax_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlenPaddedQ, devPtrCuSeqlenPaddedQ, devPtrSoftmaxAux, stream);
   }else{
     using ck_fused_attn::ck_attn_fwd;
     NVTE_CHECK_CUDA(
@@ -1014,8 +1014,8 @@ void fused_attn_ck_bwd_impl(
   void* devPtrdQWithoutPadding = nullptr;
   void* devPtrdKWithoutPadding = nullptr;
   void* devPtrdVWithoutPadding = nullptr;
-  void* ptrSeqOffsetsQ = devPtrSeqOffsetsQ;
-  void* ptrSeqOffsetsKV = devPtrSeqOffsetsKV;
+  void* devPtrCuSeqlenPaddedQ = devPtrSeqOffsetsQ;
+  void* devPtrCuSeqlenPaddedKV = devPtrSeqOffsetsKV;

   devPtrSoftmaxLSEWithoutPadding = workspace_next;
   workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float));
@@ -1082,17 +1082,16 @@
   }else if(bshd_to_thd){
     // cu_seqlen_padded ptrs for THD conversion
-    ptrSeqOffsetsQ = workspace_next;
+    devPtrCuSeqlenPaddedQ = workspace_next;
     workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t));
-    ptrSeqOffsetsKV = workspace_next;
+    devPtrCuSeqlenPaddedKV = workspace_next;
     workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t));
-    generate_cu_seqlen_padded(s_q, s_kv, b, ptrSeqOffsetsQ, ptrSeqOffsetsKV, stream);
+    generate_cu_seqlen_padded(s_q, s_kv, b, devPtrCuSeqlenPaddedQ, devPtrCuSeqlenPaddedKV, stream);
     if(nvte_log_ck_config){
       std::cout << "\nattn_bwd(ck): Converting BSHD to THD\n";
     }
   }
-
   if (nvte_log_ck_config) {
     std::cout<
diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};                       \
     size_t num_segments = input_batch;                                       \
-    if(is_ragged){                                                           \
+    if (is_ragged) {                                                         \
       auto cudnn_runtime_version = cudnnGetVersion();                        \
       num_segments = input_batch * max_segments_per_seq;                     \
       bool use_runtime_num_segments_check = false;                           \
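The ptrSeqOffsets* to devPtrCuSeqlenPaddedQ/KV rename in patch 22 makes explicit what these buffers hold: cumulative sequence starts in the padded layout, alongside the cu_seqlen buffers that hold cumulative actual lengths. A host-side C++ sketch of the relationship, assuming (as the BSHD-to-THD conversion above suggests, since generate_cu_seqlen_padded takes only s_q, s_kv, and b) that every sequence occupies a fixed padded slot:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const int b = 3, s_q = 8;  // batch of 3, fixed padded slot of 8 tokens
      const std::vector<int32_t> actual_len = {5, 8, 2};

      std::vector<int32_t> cu_seqlen(b + 1, 0);         // cumulative actual lengths
      std::vector<int32_t> cu_seqlen_padded(b + 1, 0);  // cumulative padded offsets
      for (int i = 0; i < b; ++i) {
        cu_seqlen[i + 1] = cu_seqlen[i] + actual_len[i];
        cu_seqlen_padded[i + 1] = cu_seqlen_padded[i] + s_q;  // fixed stride per sequence
      }
      // cu_seqlen        = {0, 5, 13, 15}
      // cu_seqlen_padded = {0, 8, 16, 24}
      // Valid tokens of sequence i live at [cu_seqlen_padded[i],
      // cu_seqlen_padded[i] + actual_len[i]); the rest of the slot is padding.
      for (int i = 0; i <= b; ++i)
        std::cout << cu_seqlen[i] << " / " << cu_seqlen_padded[i] << "\n";
    }

The add_padding/remove_padding helpers then translate between the two indexings when moving tensors in and out of the packed THD workspace.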
From 9ab8df4a71997cb2d82609e791948dd7cc17bb0b Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Mon, 10 Nov 2025 12:30:26 -0600
Subject: [PATCH 23/30] Updated documentation

---
 README.rst | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/README.rst b/README.rst
index 74e72efde..e7889c4ec 100644
--- a/README.rst
+++ b/README.rst
@@ -264,6 +264,14 @@ Note that when using `THD` format tensors with CK Fused Attention, one should pa
 to indicate that there is no padding between sequences. Otherwise, passing proper tensors will indicate padding
 between sequences. This is the case for both the `FusedAttention` and `DotProductAttention` modules.

+Certain runtime calculations can be enabled to potentially optimize workloads depending on the nature of the inputs:
+
+* NVTE_CK_RUNTIME_NUM_SEGMENTS - by default 0, if set to 1 then the JAX integration will calculate the number of
+segments at runtime. Enabling this requires also disabling the GPU graph by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`.
+* NVTE_CK_RUNTIME_MAX_SEQLEN - by default 0, if set to 1 then the max sequence length will be calculated at runtime.
+This can result in speedups in cases where there are many zero-length sequences. Enabling this while using the JAX
+integration requires also disabling the GPU graph by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`.
+
 FA v3 Kernels in CK Backend
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 ROCm TE provides experimental support for flash-attention v3 fwd/bwd kernels using the ck backend for limited fused attention configs.
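The NVTE_CK_RUNTIME_MAX_SEQLEN toggle documented in patch 23 is consumed inside ck_attn_varlen_fwd/bwd with the plain std::getenv pattern already visible in patch 21; a condensed, runnable sketch of that gate:

    #include <cstdlib>
    #include <iostream>
    #include <string>

    int main() {
      // Mirrors the gate in ck_attn_varlen_fwd/bwd: the optimization only
      // engages when NVTE_CK_RUNTIME_MAX_SEQLEN is set to exactly "1".
      bool runtime_max_seqlen = false;
      if (const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) {
        if (std::string(env_p) == "1") runtime_max_seqlen = true;
      }
      std::cout << "runtime max_seqlen: " << (runtime_max_seqlen ? "on" : "off") << "\n";
      // Per the README above, JAX users must also set
      // XLA_FLAGS="--xla_gpu_graph_level=0" when enabling this,
      // since the runtime recomputation is incompatible with GPU graphs.
    }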
From 2adfb6ed6463e1e26457001b88edfe36ff6a8736 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Mon, 10 Nov 2025 14:31:35 -0600
Subject: [PATCH 24/30] PR review reconciliation

- Updated debug message for BSHD-->THD conversion
- Added env variable to gate FWD output memset for padding
- Removed guards on memsets for d{Q,K,V} matrices
---
 README.rst                                   |  4 ++-
 .../common/fused_attn_rocm/fused_attn_ck.cpp | 35 ++++++++-----------
 2 files changed, 17 insertions(+), 22 deletions(-)

diff --git a/README.rst b/README.rst
index e7889c4ec..703eec5a3 100644
--- a/README.rst
+++ b/README.rst
@@ -264,13 +264,15 @@ Note that when using `THD` format tensors with CK Fused Attention, one should pa
 to indicate that there is no padding between sequences. Otherwise, passing proper tensors will indicate padding
 between sequences. This is the case for both the `FusedAttention` and `DotProductAttention` modules.

-Certain runtime calculations can be enabled to potentially optimize workloads depending on the nature of the inputs:
+Certain settings can be enabled to potentially optimize workloads depending on the nature of the inputs and expected outputs:

 * NVTE_CK_RUNTIME_NUM_SEGMENTS - by default 0, if set to 1 then the JAX integration will calculate the number of
 segments at runtime. Enabling this requires also disabling the GPU graph by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`.
 * NVTE_CK_RUNTIME_MAX_SEQLEN - by default 0, if set to 1 then the max sequence length will be calculated at runtime.
 This can result in speedups in cases where there are many zero-length sequences. Enabling this while using the JAX
 integration requires also disabling the GPU graph by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`.
+* NVTE_CK_ZERO_OUT_PAD - by default 1, if set to 0 then the output of the FA forward pass will not be initialized
+to zero, meaning invalid regions (representing padding) may take nonzero values. Only used if input has padding.

 FA v3 Kernels in CK Backend
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
index aa5dbae19..66938cf3c 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@@ -559,6 +559,7 @@
   bool nvte_ck_uses_fwd_v3 = getenv("NVTE_CK_USES_FWD_V3", 0);
   int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1);
+  bool nvte_ck_zero_out_pad = getenv("NVTE_CK_ZERO_OUT_PAD", 1);

   bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD;
   bool is_SBHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD;
@@ -672,10 +673,10 @@
     workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t));
     generate_cu_seqlen_padded(s_q, s_kv, b, devPtrCuSeqlenPaddedQ, devPtrCuSeqlenPaddedKV, stream);
     if(nvte_log_ck_config){
-      std::cout << "\nattn_fwd(ck): Converting BSHD to THD\n";
+      std::cout << "\nattn_fwd(ck): generating cu_seqlen_padded in BSHD+padding to THD+padding conversion.\n";
     }
   }
-  if(is_batch && is_padding){
+  if(is_padding && nvte_ck_zero_out_pad){
     // reset the final results since padded places need to be 0
     NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrO, 0, o_storage_bytes, stream));
   }
@@ -918,23 +919,17 @@
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
   if((layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) || (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D)){
     // just memset all dq, dk, dv
-    if(is_batch && is_padding){
-      NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes + k_storage_bytes+ v_storage_bytes, stream));
-    }
+    NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes + k_storage_bytes+ v_storage_bytes, stream));
   }else{
-    //TODO: check whether we can remove this zero-out in padding/unpadding workaround
-    // for pad between seqs case, we need to reset all dq, dk, dv
-    if(is_batch && is_padding){
-      // HD_2HD, HD_H2D, HD_HD_HD can just memset dq itself
-      NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes, stream));
-      if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){
-        //kvpacked
-        NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes + v_storage_bytes, stream));
-      }else{
-        //q, k, v separated
-        NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes, stream));
-        NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdV, 0, v_storage_bytes, stream));
-      }
+    // HD_2HD, HD_H2D, HD_HD_HD can just memset dq itself
+    NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes, stream));
+    if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){
+      //kvpacked
+      NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes + v_storage_bytes, stream));
+    }else{
+      //q, k, v separated
+      NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes, stream));
+      NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdV, 0, v_storage_bytes, stream));
     }
   }
@@ -977,8 +972,6 @@
   } else{
     NVTE_ERROR("NVTE_3HD NVTE_H3D should have h=hg.");
   }
-  // zeroing out dkv expanded in case CK requires that
-  NVTE_CHECK_CUDA(cudaMemsetAsync(dk_expanded_ptr, 0, nvte_dtype_size(dtype)*max_tokens_kv*h*(d_qk+d_v), stream));
 }

 void* devPtrAlibiSlope = nullptr;
@@ -1088,7 +1081,7 @@
     workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t));
     generate_cu_seqlen_padded(s_q, s_kv, b, devPtrCuSeqlenPaddedQ, devPtrCuSeqlenPaddedKV, stream);
     if(nvte_log_ck_config){
-      std::cout << "\nattn_bwd(ck): Converting BSHD to THD\n";
+      std::cout << "\nattn_bwd(ck): generating cu_seqlen_padded in BSHD+padding to THD+padding conversion.\n";
     }
   }

From bb3868d3a1d5d8143ff18bda787d8745d6904081 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Wed, 12 Nov 2025 11:09:05 -0600
Subject: [PATCH 25/30] Added explicit test

---
 tests/pytorch/fused_attn/test_fused_attn.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index 7c2a09a7d..864cd33fb 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -277,6 +277,27 @@ def test():
 param_types.append(torch.bfloat16)
 param_types_lean = [torch.bfloat16]

+@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
+def test_gqa_mla_thd():
+    """
+    Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration
+    post-processing for BWD FA with native padding support.
+    """
+    config = ModelConfig(8, 16, 4, 128, 128, 128, 0.0, "padding", "no_bias", head_dim_v=64)
+    qkv_layout = "thd_thd_thd"
+    dtype = torch.float16
+    _, _, fused_attn_backends = _get_attention_backends(
+        config,
+        qkv_dtype=dtype,
+        qkv_layout=qkv_layout,
+        window_size=config.window_size,
+        pad_between_seqs=True,
+    )
+    if FusedAttnBackend["CK"] not in fused_attn_backends:
+        pytest.skip("This test requires the CK fused attention backend.")
+
+    test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True, False)
+
 @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
 def test_dot_product_mem_calc():
     """
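The memset reshuffling in patches 21 and 24 leans on the QKV layout groups: for qkv-packed layouts (3HD/H3D) dQ/dK/dV sit back to back, so one cudaMemsetAsync starting at devPtrdQ covers all three gradients, while kv-packed layouts need two calls and fully separate layouts need three. A toy C++ sketch of how the zero-fill spans collapse per group (span counts inferred from the hunks above; the enum is a local stand-in for NVTE_QKV_Layout_Group):

    #include <iostream>

    enum class LayoutGroup { NVTE_3HD, NVTE_H3D, NVTE_HD_2HD, NVTE_HD_H2D, NVTE_HD_HD_HD };

    // Returns how many separate zero-fill calls the dgrad buffers need,
    // assuming packed groups place the gradient tensors contiguously.
    int memset_spans(LayoutGroup g) {
      switch (g) {
        case LayoutGroup::NVTE_3HD:
        case LayoutGroup::NVTE_H3D:
          return 1;  // one memset of q+k+v bytes starting at devPtrdQ
        case LayoutGroup::NVTE_HD_2HD:
        case LayoutGroup::NVTE_HD_H2D:
          return 2;  // dQ alone, then dK+dV back to back (kv-packed)
        default:
          return 3;  // q, k, v fully separated
      }
    }

    int main() {
      std::cout << memset_spans(LayoutGroup::NVTE_3HD) << "\n";       // 1
      std::cout << memset_spans(LayoutGroup::NVTE_HD_2HD) << "\n";    // 2
      std::cout << memset_spans(LayoutGroup::NVTE_HD_HD_HD) << "\n";  // 3
    }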
From 6206d588272a44b8140e989d6171033393549c22 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Thu, 13 Nov 2025 11:49:23 -0600
Subject: [PATCH 26/30] Formatting for bwd debug

---
 .../common/ck_fused_attn/src/ck_fused_attn_bwd.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
index 14baad367..dcc758151 100644
--- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
+++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
@@ -355,7 +355,7 @@ void log_bwd_config(const char* func_name,
   std::cout<
Date: Fri, 14 Nov 2025 15:30:07 -0600
Subject: [PATCH 27/30] Resolved error when using mixed formats e.g.
 sbhd_2bshd

---
 .../common/fused_attn_rocm/fused_attn_ck.cpp | 45 ++++++++++---------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
index 66938cf3c..cc4525604 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@@ -560,10 +560,10 @@ void fused_attn_ck_fwd_impl(
   bool nvte_ck_uses_fwd_v3 = getenv("NVTE_CK_USES_FWD_V3", 0);
   int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1);
   bool nvte_ck_zero_out_pad = getenv("NVTE_CK_ZERO_OUT_PAD", 1);
-
-  bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD;
-  bool is_SBHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_SBHD;
-  bool is_BSHD = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_BSHD;
+  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(layout);
+  bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD;
+  bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD;
+  bool is_BSHD = qkv_format==NVTE_QKV_Format::NVTE_BSHD;
   bool is_batch = is_BSHD || is_SBHD;

   bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
@@ -689,6 +689,8 @@ void fused_attn_ck_fwd_impl(
     std::cout<<"is_batch: "<
Date: Fri, 14 Nov 2025 15:31:26 -0600
Subject: [PATCH 28/30] Updated guard on flash-attention forced support

---
 tests/pytorch/fused_attn/test_fused_attn.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index c4be44fb7..899207a6e 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -389,6 +389,7 @@ def test_dot_product_attention(
             and config.attn_mask_type in ["causal", "padding_causal"]
         )
         and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
+        and not is_mla
     ):
         flash_attn_supported = True

From 85bb6f6e56c207e4c22800afa5c81fe8a9d1b21f Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Fri, 14 Nov 2025 15:40:51 -0600
Subject: [PATCH 29/30] Added check for SBHD_2BSHD

---
 transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
index cc4525604..ea91a613f 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@@ -562,7 +562,7 @@ void fused_attn_ck_fwd_impl(
   bool nvte_ck_zero_out_pad = getenv("NVTE_CK_ZERO_OUT_PAD", 1);
   NVTE_QKV_Format qkv_format = nvte_get_qkv_format(layout);
   bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD;
-  bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD;
+  bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD;
   bool is_BSHD = qkv_format==NVTE_QKV_Format::NVTE_BSHD;
   bool is_batch = is_BSHD || is_SBHD;
@@ -846,7 +846,7 @@ void fused_attn_ck_bwd_impl(
   NVTE_QKV_Format qkv_format = nvte_get_qkv_format(layout);
   bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD;
-  bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD;
+  bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD;
   bool is_BSHD = qkv_format==NVTE_QKV_Format::NVTE_BSHD;
   bool is_batch = is_BSHD || is_SBHD;
   bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
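Patches 27 and 29 hinge on evaluating nvte_get_qkv_format(layout) once, caching it, and deriving every format flag from that cached value, so that mixed formats such as SBHD_2BSHD (Q stored SBHD, KV stored BSHD) are classified as a batch/SBHD case instead of falling through every branch. A reduced C++ sketch of the classification, with a stand-in Format enum replacing NVTE_QKV_Format:

    #include <iostream>

    enum class Format { THD, SBHD, BSHD, SBHD_2BSHD };

    struct Flags { bool is_ragged, is_SBHD, is_BSHD, is_batch; };

    // Mirrors the flag derivation in fused_attn_ck_fwd_impl/bwd_impl after
    // patches 27 and 29: the mixed SBHD_2BSHD format is routed as SBHD.
    Flags classify(Format qkv_format) {
      Flags f{};
      f.is_ragged = qkv_format == Format::THD;
      f.is_SBHD = qkv_format == Format::SBHD || qkv_format == Format::SBHD_2BSHD;
      f.is_BSHD = qkv_format == Format::BSHD;
      f.is_batch = f.is_BSHD || f.is_SBHD;
      return f;
    }

    int main() {
      Flags f = classify(Format::SBHD_2BSHD);
      std::cout << f.is_batch << f.is_SBHD << f.is_ragged << "\n";  // prints 110
    }

Caching the format also avoids repeating the nvte_get_qkv_format call for each flag, which is where the original per-flag evaluation went wrong for the mixed formats.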
From a12105d580e9526bd0ae3e9c8c4b0167e5285d7c Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Fri, 14 Nov 2025 15:45:08 -0600
Subject: [PATCH 30/30] Added guard on dk/dv memset

---
 .../common/fused_attn_rocm/fused_attn_ck.cpp | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
index ea91a613f..ced96ac4d 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@@ -838,7 +838,8 @@ void fused_attn_ck_bwd_impl(
   bool nvte_ck_uses_bwd_v3 = getenv("NVTE_CK_USES_BWD_V3", 0);
   bool nvte_ck_is_v3_atomic_fp32 = getenv("NVTE_CK_IS_V3_ATOMIC_FP32", 1);
   int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1);
-
+  bool nvte_ck_zero_out_pad = getenv("NVTE_CK_ZERO_OUT_PAD", 1);
+
   bool is_mqa_gqa = (h > hg);

   size_t kN0 = (d_qk <= 128)? 128:64;
@@ -926,13 +927,15 @@
   }else{
     // HD_2HD, HD_H2D, HD_HD_HD can just memset dq itself
     NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes, stream));
-    if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){
-      //kvpacked
-      NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes + v_storage_bytes, stream));
-    }else{
-      //q, k, v separated
-      NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes, stream));
-      NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdV, 0, v_storage_bytes, stream));
+    if(is_padding && nvte_ck_zero_out_pad){
+      if(layout_group==NVTE_QKV_Layout_Group::NVTE_HD_2HD ||layout_group==NVTE_QKV_Layout_Group::NVTE_HD_H2D){
+        //kvpacked
+        NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes + v_storage_bytes, stream));
+      }else{
+        //q, k, v separated
+        NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdK, 0, k_storage_bytes, stream));
+        NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdV, 0, v_storage_bytes, stream));
+      }
     }
   }
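The getenv(name, default) calls used throughout these patches come from a TE-internal helper that this series does not show; a minimal C++ stand-in with the same shape (getenv_int is hypothetical, not the real TE helper), illustrating how NVTE_CK_ZERO_OUT_PAD's default of 1 combines with is_padding to gate the dK/dV zero-fill reinstated in patch 30:

    #include <cstdlib>
    #include <iostream>

    // Hypothetical stand-in for TE's getenv helper: parse an integer
    // environment variable, falling back to a default when unset.
    static int getenv_int(const char* name, int default_value) {
      const char* env_p = std::getenv(name);
      return env_p ? std::atoi(env_p) : default_value;
    }

    int main() {
      bool nvte_ck_zero_out_pad = getenv_int("NVTE_CK_ZERO_OUT_PAD", 1) != 0;
      bool is_padding = true;  // e.g. NVTE_PADDING_MASK in effect
      if (is_padding && nvte_ck_zero_out_pad) {
        std::cout << "would zero the dK/dV padding regions before the bwd launch\n";
      } else {
        std::cout << "skipping dgrad zero-fill\n";
      }
    }

With the default left at 1 the padded regions of the gradients are zeroed as before; setting NVTE_CK_ZERO_OUT_PAD=0 trades that guarantee for fewer memsets, matching the README note from patch 24.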