diff --git a/3rdparty/aiter b/3rdparty/aiter index 7a41cca67..e6cef50bf 160000 --- a/3rdparty/aiter +++ b/3rdparty/aiter @@ -1 +1 @@ -Subproject commit 7a41cca67187bd5f77c337765a1a289337901cef +Subproject commit e6cef50bf69a1c51f285210921572198daddfe6a diff --git a/README.rst b/README.rst index 74e72efde..703eec5a3 100644 --- a/README.rst +++ b/README.rst @@ -264,6 +264,16 @@ Note that when using `THD` format tensors with CK Fused Attention, one should pa to indicate that there is no padding between sequences. Otherwise, passing proper tensors will indicate padding between sequences. This is the case for both the `FusedAttention` and `DotProductAttention` modules. +Certain settings can be enabled to potentially optimize workloads depending on the nature of the inputs and expected outputs: + +* NVTE_CK_RUNTIME_NUM_SEGMENTS - by default 0, if set to 1 then the JAX integration will calculate the number of +segments at runtime. Enabling this requires also disabling the GPU graph by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`. +* NVTE_CK_RUNTIME_MAX_SEQLEN - by default 0, if set to 1 then the max sequence length will be calculated at runtime. +This can result in speedups in cases where there are many zero-length sequences. Enabling this while using the JAX +integration requires also disabling the GPU graph by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`. +* NVTE_CK_ZERO_OUT_PAD - by default 1, if set to 0 then the output of the FA forward pass will not be initialized +to zero, meaning invalid regions (representing padding) may take nonzero values. Only used if input has padding. + FA v3 Kernels in CK Backend ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ROCm TE provides experimental support for flash-attention v3 fwd/bwd kernels using the ck backend for limited fused attention configs. diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 43f1d4488..899207a6e 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -277,6 +277,27 @@ def test(): param_types.append(torch.bfloat16) param_types_lean = [torch.bfloat16] +@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.") +def test_gqa_mla_thd(): + """ + Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration + post-processing for BWD FA with native padding support. 
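+    Reading the ModelConfig positional arguments as (batch, heads, KV heads,
+    head_dim_qk, seqlen_q, seqlen_kv, dropout, mask, bias), the case below pairs
+    GQA (16 query heads over 4 KV heads) with head_dim_v=64 != head_dim_qk=128
+    in a THD layout with pad_between_seqs=True, so the backward reduction must
+    take the separate dk_or_dv_reduce_thd path instead of the joint dk_dv one.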
+ """ + config = ModelConfig(8, 16, 4, 128, 128, 128, 0.0, "padding", "no_bias", head_dim_v=64) + qkv_layout = "thd_thd_thd" + dtype = torch.float16 + _, _, fused_attn_backends = _get_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=True, + ) + if FusedAttnBackend["CK"] not in fused_attn_backends: + pytest.skip("This test requires the CK fused attention backend.") + + test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True, False) + @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.") def test_dot_product_mem_calc(): """ @@ -368,6 +389,7 @@ def test_dot_product_attention( and config.attn_mask_type in ["causal", "padding_causal"] ) and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus) + and not is_mla ): flash_attn_supported = True diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index d174a48f4..54ee94786 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -59,6 +59,7 @@ hipError_t ck_attn_fwd( uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, void* lse_ptr, bool uses_fwd_v3, + int how_v3_bf16_cvt, hipStream_t stream); hipError_t ck_attn_varlen_fwd( @@ -72,6 +73,7 @@ hipError_t ck_attn_varlen_fwd( const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, bool is_training, float scaling_factor, float dropout_probability, @@ -82,6 +84,7 @@ hipError_t ck_attn_varlen_fwd( uint64_t stride_h_o, uint64_t stride_s_o, void* lse_thd_ptr, bool uses_fwd_v3, + int how_v3_bf16_cvt, hipStream_t stream); hipError_t ck_attn_bwd( @@ -137,6 +140,7 @@ hipError_t ck_attn_varlen_bwd( const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, const void* o_ptr, uint64_t stride_h_o, uint64_t stride_s_o, const void* lse_thd_ptr, diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 2b717ace0..dcc758151 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -15,6 +15,24 @@ namespace ck_fused_attn{ +// TODO: unify with binary search in TE/common/fused_attn(rocm)/util +// no device std::upper_bound +// in an increasing array with given size len, search for the index that: +// array[index] <= target < array[index+1] +// guaranteed that target >=0 and target <= cu_seqlen[end-1] +__forceinline__ __device__ int binary_search(int32_t target, const int32_t *array, uint64_t len) { + int left = 1, right = len - 1; + while (left < right) { + int mid = (left + right) / 2; + if (array[mid] <= target) { + left = mid + 1; + } else { + right = mid; + } + } + return left - 1; +} + // define dk_dv_reduce function only for fp16 and bf16 types template __global__ void dk_dv_reduce( @@ -109,8 +127,9 @@ __global__ void dk_or_dv_reduce( // define dk_dv_reduce function in THD layout only for fp16 and bf16 types template __global__ void 
dk_dv_reduce_thd( - uint64_t h, uint64_t hg, uint64_t d, - const int32_t* total_seqlen_kv_ptr, + uint64_t b, uint64_t h, uint64_t hg, uint64_t d, + const int32_t* cu_seqlen_kv_ptr, + const int32_t* cu_seqlen_kv_padded_ptr, const DataType *dk_expanded, const DataType *dv_expanded, uint64_t stride_h_dkv_expanded, uint64_t stride_s_dkv_expanded, @@ -124,11 +143,17 @@ __global__ void dk_dv_reduce_thd( uint64_t hdim_idx = threadIdx.x; assert(hdim_idx= *total_seqlen_kv_ptr){ + + if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ return; } - + if(cu_seqlen_kv_padded_ptr){ + uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1); + uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx]; + if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){ + return; + } + } // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -164,8 +189,9 @@ __global__ void dk_dv_reduce_thd( // When d_qk != d_v, we need to reduce dk and dv separately template __global__ void dk_or_dv_reduce_thd( - uint64_t h, uint64_t hg, uint64_t d, - const int32_t* total_seqlen_kv_ptr, + uint64_t b, uint64_t h, uint64_t hg, uint64_t d, + const int32_t* cu_seqlen_kv_ptr, + const int32_t* cu_seqlen_kv_padded_ptr, const DataType *dk_or_dv_expanded, uint64_t stride_h_dk_or_dv_expanded, uint64_t stride_s_dk_or_dv_expanded, DataType *dk_or_dv, @@ -178,10 +204,16 @@ __global__ void dk_or_dv_reduce_thd( assert(hdim_idx= *total_seqlen_kv_ptr){ + if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ return; } - + if(cu_seqlen_kv_padded_ptr){ + uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1); + uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx]; + if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){ + return; + } + } // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -323,7 +355,7 @@ void log_bwd_config(const char* func_name, std::cout<{philox_seed_ptr, philox_offset_ptr}}; }(); + // modify the max_seqlen_q for better performance in 0-length cases + // lse_thd_ptr used as buffer + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { + if(std::string(env_p) == "1"){ + if(ck_fused_attn_log_config){ + std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + } + fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); + fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); + } + } + // print ck traits and args when needed log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); if (uses_bwd_v3) @@ -985,17 +1042,17 @@ hipError_t ck_attn_varlen_bwd( } float average_runtime = aiter::mha_bwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_enum::no_bias, - has_dbias, - s_randval, - deterministic, - uses_bwd_v3, - is_v3_atomic_fp32, - how_v3_bf16_cvt); + stream_config, + data_type_str, + is_group_mode, + mask_type, + bias_enum::no_bias, + has_dbias, + s_randval, + deterministic, + uses_bwd_v3, + is_v3_atomic_fp32, + how_v3_bf16_cvt); if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd 
pass."); @@ -1006,6 +1063,8 @@ hipError_t ck_attn_varlen_bwd( dim3 block(d_qk); if (ck_fused_attn_log_config){ std::cout<, grid, block, 0, stream, - h, hg, d_qk, - static_cast(cu_seqlen_kv_ptr)+b, + b, h, hg, d_qk, + static_cast(cu_seqlen_kv_ptr), + static_cast(cu_seqlen_kv_padded_ptr), static_cast(dk_expanded_ptr), static_cast(dv_expanded_ptr), stride_h_dk_expanded, stride_s_dk_expanded, @@ -1030,6 +1090,8 @@ hipError_t ck_attn_varlen_bwd( dim3 block_dk(d_qk); if (ck_fused_attn_log_config){ std::cout<, grid, block_dk, 0, stream, - h, hg, d_qk, - static_cast(cu_seqlen_kv_ptr)+b, + b, h, hg, d_qk, + static_cast(cu_seqlen_kv_ptr), + static_cast(cu_seqlen_kv_padded_ptr), static_cast(dk_expanded_ptr), stride_h_dk_expanded, stride_s_dk_expanded, static_cast(dk_ptr), @@ -1050,6 +1113,8 @@ hipError_t ck_attn_varlen_bwd( dim3 block_dv(d_v); if (ck_fused_attn_log_config){ std::cout<, grid, block_dv, 0, stream, - h, hg, d_v, - static_cast(cu_seqlen_kv_ptr)+b, + b, h, hg, d_v, + static_cast(cu_seqlen_kv_ptr), + static_cast(cu_seqlen_kv_padded_ptr), static_cast(dv_expanded_ptr), stride_h_dv_expanded, stride_s_dv_expanded, static_cast(dv_ptr), diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index c87a3db6c..9f87f8b13 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -27,6 +27,7 @@ void log_fwd_config(const char* func_name, const bool is_v_rowmajor, const bool do_fp8_static_quant, const bool uses_fwd_v3, + const bool how_v3_bf16_cvt, const fmha_fwd_args& fmha_args){ bool ck_fused_attn_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { @@ -37,22 +38,25 @@ void log_fwd_config(const char* func_name, std::cout<::type>(mask_type)<::type>(bias_type)<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))< 0.f); @@ -209,13 +224,12 @@ hipError_t ck_attn_fwd( nullptr,//rand_val_ptr lse_ptr, o_ptr, - nullptr, //cu_seqlen_q - nullptr, //cu_seqlen_kv nullptr, //seqstart_q_ptr nullptr, //seqstart_k_ptr + nullptr, //seqlen_q_ptr nullptr, //seqlen_k_ptr - nullptr, //seqstart_padded_q_ptr - nullptr, //seqstart_padded_k_ptr + nullptr, //cu_padded_q_ptr + nullptr, //cu_padded_k_ptr max_seqlen_q, max_seqlen_k, batch, @@ -258,7 +272,7 @@ hipError_t ck_attn_fwd( }(); // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, fmha_args); + log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); if (uses_fwd_v3) { set_aiter_asm_dir(); @@ -271,7 +285,8 @@ hipError_t ck_attn_fwd( mask_type, bias_type, has_lse, - uses_fwd_v3); + uses_fwd_v3, + how_v3_bf16_cvt); if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); @@ -290,6 +305,7 @@ hipError_t ck_attn_varlen_fwd( const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, bool is_training, float scaling_factor, float dropout_probability, @@ -300,6 +316,7 @@ hipError_t ck_attn_varlen_fwd( uint64_t 
stride_h_o, uint64_t stride_s_o, void* lse_thd_ptr, bool uses_fwd_v3, + int how_v3_bf16_cvt, hipStream_t stream){ bool has_dropout = (is_training && dropout_probability > 0.f); @@ -384,13 +401,12 @@ hipError_t ck_attn_varlen_fwd( nullptr,//rand_val_ptr lse_thd_ptr, o_ptr, - nullptr, //cu_seqlen_q - nullptr, //cu_seqlen_kv - cu_seqlen_q_ptr, //seqstart_q_ptr - cu_seqlen_kv_ptr, //seqstart_k_ptr + cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr + cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr + nullptr, //seqlen_q_ptr nullptr, //seqlen_k_ptr - nullptr, //seqstart_padded_q_ptr - nullptr, //seqstart_padded_k_ptr + cu_seqlen_q_ptr, //cu_seqlen_q_ptr + cu_seqlen_kv_ptr, //cu_seqlen_k_ptr max_seqlen_q, //seqlen_q, unused in group mode max_seqlen_kv, //seqlen_kv, unused in group mode batch, @@ -431,22 +447,33 @@ hipError_t ck_attn_varlen_fwd( false, std::pair{philox_seed_ptr, philox_offset_ptr}}; }(); - + // modify the max_seqlen_q for better performance in 0-length cases + // lse_thd_ptr used as buffer + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ + if(std::string(env_p) == "1"){ + if(ck_fused_attn_log_config){ + std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + } + fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_thd_ptr, stream); + } + } // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, fmha_args); + log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); if (uses_fwd_v3) { set_aiter_asm_dir(); } - float average_runtime = aiter::mha_fwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_lse, - uses_fwd_v3); + float average_runtime = aiter::mha_fwd( + fmha_args, + stream_config, + data_type_str, + is_group_mode, + mask_type, + bias_type, + has_lse, + uses_fwd_v3, + how_v3_bf16_cvt); if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index ab9d28b2e..56618335f 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -100,4 +100,39 @@ std::pair get_ck_bias_type_shape(BiasType attn_bias_type, return std::make_pair(bias_type, bias_shape); } +__global__ void get_runtime_max_seqlen_kernel( + uint64_t b, + const int32_t* cu_seqlen_ptr, + const int32_t* cu_seqlen_padded_ptr, + uint64_t *out) { + + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if(tid >= b){ + return; + } + if(cu_seqlen_padded_ptr){ + atomicMax(out, cu_seqlen_padded_ptr[tid+1] - cu_seqlen_padded_ptr[tid]); + }else{ + atomicMax(out, cu_seqlen_ptr[tid+1] - cu_seqlen_ptr[tid]); + } +} + +uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream){ + uint64_t* runtime_max_seqlen_ptr = static_cast(workspace); + uint64_t runtime_max_seqlen; + //reset the result 
buffer to 0 + hipMemsetAsync(runtime_max_seqlen_ptr, 0, sizeof(uint64_t), stream); + constexpr int threads = 128; + // in case b ==0 + const int blocks = (static_cast(b) - 1) / threads + 1; // ceil + get_runtime_max_seqlen_kernel<<>>( + b, + static_cast(cu_seqlen_ptr), + static_cast(cu_seqlen_padded_ptr), + runtime_max_seqlen_ptr); + hipMemcpyAsync(&runtime_max_seqlen, runtime_max_seqlen_ptr, sizeof(uint64_t), hipMemcpyDeviceToHost, stream); + hipStreamSynchronize(stream); + return runtime_max_seqlen; +} + }//namespace ck_fused_attn diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index b0ce02ee0..d248d65ca 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -9,6 +9,7 @@ #include #include +#include //forward declaration for ck_tile enum enum class mask_enum; @@ -54,5 +55,7 @@ BiasShape get_bias_shape(uint64_t b, uint64_t h, uint64_t bias_b, uint64_t bias_ std::pair get_ck_bias_type_shape(BiasType attn_bias_type, uint64_t b, uint64_t h, uint64_t bias_b, uint64_t bias_h); void set_aiter_asm_dir(); +uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream); + }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_UTILS_H diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 66fa72c0c..a56a3e4c4 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -868,10 +868,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } } -uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len, +uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t max_batch_size, cudaStream_t stream) { NVTE_API_CALL(nvte_get_runtime_num_segments); - return transformer_engine::fused_attn_rocm::GetRuntimeNumSegments(cu_seqlen, workspace, len, stream); + return transformer_engine::fused_attn_rocm::GetRuntimeNumSegments(cu_seqlen, workspace, max_batch_size, stream); } void nvte_populate_rng_state_async(void *rng_state_dst, const void *const seed, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index b38249f5b..ced96ac4d 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -18,34 +18,6 @@ namespace transformer_engine { namespace fused_attn_rocm { -bool get_pad_between_seqs( - const Tensor* input_cu_seqlens, - const Tensor* input_cu_seqlens_padded, - bool is_ragged, bool is_padding -){ - // First we check whether we have a ragged array with a non-trivial - // input_cu_seqlens_padded tensor - bool pad_between_seqs = ( - is_ragged - && input_cu_seqlens->data.dptr!=input_cu_seqlens_padded->data.dptr - && !input_cu_seqlens_padded->data.shape.empty() - ); - // Next we guard against an initial workspace-allocation which occurs in the - // JAX TE extension. We check for both pointers being null while retaining - // shape data, indicating the use of dummy data in the allocation pass. 
- pad_between_seqs = pad_between_seqs || ( - is_ragged - && input_cu_seqlens->data.dptr==nullptr && !input_cu_seqlens->data.shape.empty() - && input_cu_seqlens_padded->data.dptr==nullptr && !input_cu_seqlens_padded->data.shape.empty() - ); - // Finally we check whether we have an array with padding and non-empty input_cu_seqlens - pad_between_seqs = pad_between_seqs || ( - !is_ragged - && is_padding - && !input_cu_seqlens->data.shape.empty() - ); - return pad_between_seqs; -} // check the fused attn config to see whether it's ck backend supported // single filtering followed by joint filtering bool is_ck_backend_supported( @@ -215,6 +187,34 @@ ck_fused_attn::MaskType set_ck_mask(NVTE_Mask_Type nvte_mask_type, int64_t nvte_ return ck_fused_attn::MaskType::window_generic; } +__global__ +void generate_cu_seqlen_padded_kernel( + uint32_t s_q, uint32_t s_kv, uint32_t b, + int32_t* cu_seqlen_q_padded_ptr, + int32_t* cu_seqlen_kv_padded_ptr +){ + for(int i = blockIdx.x * blockDim.x + threadIdx.x; i < b+1; i += blockDim.x * gridDim.x){ + cu_seqlen_q_padded_ptr[i] = s_q * i; + cu_seqlen_kv_padded_ptr[i] = s_kv * i; + } +} + +void generate_cu_seqlen_padded( + uint32_t s_q, uint32_t s_kv, uint32_t b, + void* cu_seqlen_q_padded_ptr, + void* cu_seqlen_kv_padded_ptr, + hipStream_t stream +){ + constexpr int THREADS_PER_BLOCK = 256; + dim3 block(THREADS_PER_BLOCK); + dim3 grid(ceil(1.0 * (b+1)/THREADS_PER_BLOCK)); + generate_cu_seqlen_padded_kernel<<>>( + s_q, s_kv, b, + static_cast(cu_seqlen_q_padded_ptr), + static_cast(cu_seqlen_kv_padded_ptr) + ); +} + __global__ void generate_alibi_slope(uint64_t h, float* alibi_slope_ptr){ for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < h; id += blockDim.x * gridDim.x){ @@ -536,7 +536,7 @@ void remove_padding_softmax_lse( // actual fwd implementation, calling ck api directly void fused_attn_ck_fwd_impl( uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, - bool pad_between_seqs, size_t max_tokens_q, size_t max_tokens_kv, + size_t max_tokens_q, size_t max_tokens_kv, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -556,10 +556,21 @@ void fused_attn_ck_fwd_impl( if (env_p != nullptr && std::string(env_p) == "1") nvte_log_ck_config = true; } - bool nvte_ck_uses_fwd_v3 = getenv("NVTE_CK_USES_FWD_V3", 0); - - bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; + bool nvte_ck_uses_fwd_v3 = getenv("NVTE_CK_USES_FWD_V3", 0); + int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1); + bool nvte_ck_zero_out_pad = getenv("NVTE_CK_ZERO_OUT_PAD", 1); + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(layout); + bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD; + bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD; + bool is_BSHD = qkv_format==NVTE_QKV_Format::NVTE_BSHD; + bool is_batch = is_BSHD || is_SBHD; + + bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + bool bshd_to_thd = is_BSHD && is_padding; + // extract the qkv and o storage bytes to allocate buffer for padding removing // b from cu_seqlen is not the actual storage batch for pad_between_seqs case size_t q_storage_bytes = max_tokens_q*h*d_qk*nvte_dtype_size(dtype); @@ -573,15 +584,13 @@ void 
fused_attn_ck_fwd_impl( if(bias_type == NVTE_Bias_Type::NVTE_ALIBI){ (*workspace_size)+= h*sizeof(float); } - if(pad_between_seqs){ - // softmax_lse buffer - (*workspace_size)+= max_tokens_q*h*sizeof(float); + (*workspace_size)+= max_tokens_q*h*sizeof(float); + if(is_SBHD && is_padding){ // request q, k, v, o buffer without padding (*workspace_size)+= q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes; - }else if(is_ragged){ - // We include a softmax_lse buffer to use the kernel in order to properly reshape the lse as needed. - (*workspace_size)+= max_tokens_q*h*sizeof(float); - + }else if(bshd_to_thd){ + // cu_seqlen_padded buffers + (*workspace_size)+= 2*(b+1)*sizeof(int32_t); } if (nvte_log_ck_config) { std::cout<(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); + void* devPtrCuSeqlenPaddedQ = devPtrSeqOffsetsQ; + void* devPtrCuSeqlenPaddedKV = devPtrSeqOffsetsKV; + + + // next h*max_tokens_q*sizeof(float) in workspace are for lse buffer + devPtrSoftmaxLSEWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); + if(is_SBHD && is_padding){ + //determine the o buffer based on workspace next section + devPtrOWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + o_storage_bytes); + //determine q, k ,v buffer based on the workspace next ptr and layout group NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); //Q ptr always comes at first @@ -648,15 +665,20 @@ void fused_attn_ck_fwd_impl( devPtrVWithoutPadding = workspace_next; workspace_next = static_cast(static_cast(workspace_next) + v_storage_bytes); } - //determine the o buffer based on workspace next section - devPtrOWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + o_storage_bytes); + }else if(bshd_to_thd){ + // cu_seqlen_padded ptrs for THD conversion + devPtrCuSeqlenPaddedQ = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); + devPtrCuSeqlenPaddedKV = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); + generate_cu_seqlen_padded(s_q, s_kv, b, devPtrCuSeqlenPaddedQ, devPtrCuSeqlenPaddedKV, stream); + if(nvte_log_ck_config){ + std::cout << "\nattn_fwd(ck): generating cu_seqlen_padded in BSHD+padding to THD+padding conversion.\n"; + } + } + if(is_padding && nvte_ck_zero_out_pad){ // reset the final results since padded places need to be 0 NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrO, 0, o_storage_bytes, stream)); - }else if(is_ragged){ - // next h*max_tokens_q*sizeof(float) in workspace are for lse buffer - devPtrSoftmaxLSEWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); } if (nvte_log_ck_config) { @@ -664,7 +686,12 @@ void fused_attn_ck_fwd_impl( std::cout<<"layout: "<("NVTE_CK_USES_BWD_V3", 0); + bool nvte_ck_is_v3_atomic_fp32 = getenv("NVTE_CK_IS_V3_ATOMIC_FP32", 1); + int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1); + bool nvte_ck_zero_out_pad = getenv("NVTE_CK_ZERO_OUT_PAD", 1); bool is_mqa_gqa = (h > hg); size_t kN0 = (d_qk <= 128)? 128:64; size_t nsplits = deterministic? 
ceil(1.0*s_kv/kN0):1; - bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(layout); + bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD; + bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD; + bool is_BSHD = qkv_format==NVTE_QKV_Format::NVTE_BSHD; + bool is_batch = is_BSHD || is_SBHD; + bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + bool bshd_to_thd = is_BSHD && is_padding; + // extract the qkv and o storage bytes to allocate buffer for padding removing // b from cu_seqlen is not the actual storage batch for pad_between_seqs case size_t q_storage_bytes = max_tokens_q*h*d_qk*nvte_dtype_size(dtype); @@ -833,14 +880,17 @@ void fused_attn_ck_bwd_impl( //ck requires a buffer dbias_expanded of size BHSS if bias is not BHSS (*workspace_size) += b*h*s_q*s_kv*nvte_dtype_size(dtype); } - if(pad_between_seqs){ - // remove padding for the softmax_lse - (*workspace_size)+= h*max_tokens_q*sizeof(float); - // allocate the q, k, v, o, do, dq, dk, dv, - (*workspace_size)+= 2*(q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes); - }else if(is_ragged){ - // remove padding for the softmax_lse - (*workspace_size)+= h*max_tokens_q*sizeof(float); + // remove padding for the softmax_lse + (*workspace_size)+= h*max_tokens_q*sizeof(float); + if(is_SBHD && is_padding){ + // allocate the q, k, v, o, do, dq, dk, dv, + (*workspace_size)+= 2*(q_storage_bytes + k_storage_bytes + v_storage_bytes + o_storage_bytes); + if (nvte_log_ck_config) { + std::cout<(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); + void* devPtrCuSeqlenPaddedQ = devPtrSeqOffsetsQ; + void* devPtrCuSeqlenPaddedKV = devPtrSeqOffsetsKV; + + devPtrSoftmaxLSEWithoutPadding = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); + if(is_SBHD && is_padding){ //determine q, k, v buffer based on the workspace next ptr and layout group NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); //Q ptr always comes at first @@ -1030,22 +1079,27 @@ void fused_attn_ck_bwd_impl( // zeroing out just the dq itself NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQWithoutPadding, 0, q_storage_bytes, stream)); } - }else if(is_ragged){ - devPtrSoftmaxLSEWithoutPadding = workspace_next; - workspace_next = static_cast(static_cast(workspace_next) + h*max_tokens_q*sizeof(float)); + }else if(bshd_to_thd){ + // cu_seqlen_padded ptrs for THD conversion + devPtrCuSeqlenPaddedQ = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); + devPtrCuSeqlenPaddedKV = workspace_next; + workspace_next = static_cast(static_cast(workspace_next) + (b+1)*sizeof(int32_t)); + generate_cu_seqlen_padded(s_q, s_kv, b, devPtrCuSeqlenPaddedQ, devPtrCuSeqlenPaddedKV, stream); + if(nvte_log_ck_config){ + std::cout << "\nattn_bwd(ck): generating cu_seqlen_padded in BSHD+padding to THD+padding conversion.\n"; + } } - // bwd v3 is optional by enabling the following envs - // default values follows the ck example setting - bool nvte_ck_uses_bwd_v3 = getenv("NVTE_CK_USES_BWD_V3", 0); - bool nvte_ck_is_v3_atomic_fp32 = getenv("NVTE_CK_IS_V3_ATOMIC_FP32", 1); - int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1); if (nvte_log_ck_config) { std::cout<data).shape.begin(), 
(input_QKV->data).shape.end(), static_cast(1), std::multiplies())/h/d/3; @@ -1433,7 +1483,7 @@ void fused_attn_ck_bwd_qkvpacked( fused_attn_ck_bwd_impl( b, h, h, max_seqlen, max_seqlen, d, d, bias_b, bias_h, - pad_between_seqs, max_tokens, max_tokens, + max_tokens, max_tokens, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, @@ -1575,11 +1625,10 @@ void fused_attn_ck_fwd_kvpacked( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - bool pad_between_seqs = get_pad_between_seqs(input_cu_seqlens_q, input_cu_seqlens_q_padded, is_ragged, is_padding); fused_attn_ck_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, bias_b, bias_h, - pad_between_seqs, max_tokens_q, max_tokens_kv, + max_tokens_q, max_tokens_kv, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, @@ -1676,7 +1725,6 @@ void fused_attn_ck_bwd_kvpacked( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - bool pad_between_seqs = get_pad_between_seqs(input_cu_seqlens_q, input_cu_seqlens_q_padded, is_ragged, is_padding); // extract the max_tokens for padding/unpadding and softmax_lse buffer // b from cu_seqlen and max_seqlen are not the actual storage batch and seqlen for pad_between_seqs case @@ -1685,7 +1733,7 @@ void fused_attn_ck_bwd_kvpacked( fused_attn_ck_bwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, bias_b, bias_h, - pad_between_seqs, max_tokens_q, max_tokens_kv, + max_tokens_q, max_tokens_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, @@ -1817,11 +1865,10 @@ void fused_attn_ck_fwd( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - bool pad_between_seqs = get_pad_between_seqs(input_cu_seqlens_q, input_cu_seqlens_q_padded, is_ragged, is_padding); fused_attn_ck_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h, - pad_between_seqs, max_tokens_q, max_tokens_kv, + max_tokens_q, max_tokens_kv, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, @@ -1907,7 +1954,6 @@ void fused_attn_ck_bwd( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - bool pad_between_seqs = get_pad_between_seqs(input_cu_seqlens_q, input_cu_seqlens_q_padded, is_ragged, is_padding); // extract the max_tokens for padding/unpadding and softmax_lse buffer // b from cu_seqlen and max_seqlen are not the actual storage batch and seqlen for pad_between_seqs case @@ -1916,7 +1962,7 @@ void fused_attn_ck_bwd( fused_attn_ck_bwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h, - pad_between_seqs, max_tokens_q, max_tokens_kv, + max_tokens_q, max_tokens_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, diff --git a/transformer_engine/common/fused_attn_rocm/utils.cpp b/transformer_engine/common/fused_attn_rocm/utils.cpp index 5e9b0b67f..1902665d7 100644 --- a/transformer_engine/common/fused_attn_rocm/utils.cpp +++ b/transformer_engine/common/fused_attn_rocm/utils.cpp @@ -229,11 +229,11 @@ 
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t rng_state_dst[1] = offset; } -__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) { +__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t max_batch_size, uint32_t *out) { int tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid >= len) return; + if (tid >= max_batch_size) return; - if (cu_seqlen[tid] > 0) { + if (cu_seqlen[tid+1] - cu_seqlen[tid] > 0) { // atomicAdd only support 32 bits dtype atomicAdd(out, 1); } @@ -253,15 +253,15 @@ void PopulateRngStateAsync(void *rng_state_dst, NVTE_CHECK_CUDA(cudaGetLastError()); } -uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) { +uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t max_batch_size, cudaStream_t stream) { // workspace size requires 4 bytes uint32_t *dout = static_cast(workspace); uint32_t hout{}; cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream); constexpr int threads = 128; - const int blocks = (len - 1) / threads + 1; + const int blocks = (max_batch_size - 1) / threads + 1; // ceil get_runtime_num_segments_kernel<<>>(static_cast(cu_seqlen), - len, dout); + max_batch_size, dout); cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream); cudaStreamSynchronize(stream); return hout; diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index af1fcb493..283be077d 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -236,13 +236,16 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t num_segments = input_batch; \ if (is_ragged) { \ auto cudnn_runtime_version = cudnnGetVersion(); \ - if (cudnn_runtime_version >= 90300) { \ - num_segments = input_batch * max_segments_per_seq; \ - } else { \ + num_segments = input_batch * max_segments_per_seq; \ + bool use_runtime_num_segments_check = false; \ + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_NUM_SEGMENTS")){ \ + use_runtime_num_segments_check = std::string(env_p) == "1"; \ + } \ + if(cudnn_runtime_version < 90300 || use_runtime_num_segments_check){ \ size_t runtime_num_segments_q = nvte_get_runtime_num_segments( \ - q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \ + q_cu_seqlens, workspace, input_batch * max_segments_per_seq, stream); \ size_t runtime_num_segments_kv = nvte_get_runtime_num_segments( \ - kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \ + kv_cu_seqlens, workspace, input_batch * max_segments_per_seq, stream); \ NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); \ NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); \ num_segments = runtime_num_segments_q; \
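For reference, a minimal host-side sketch of the two runtime scans this patch gates behind NVTE_CK_RUNTIME_NUM_SEGMENTS and NVTE_CK_RUNTIME_MAX_SEQLEN (function names here are illustrative, not the patch's device kernels). A segment is now counted only when cu_seqlen[i+1] - cu_seqlen[i] > 0, whereas the previous check (cu_seqlen[tid] > 0) counted positive cumulative offsets rather than non-empty segments; the max-seqlen scan takes the largest per-segment length, preferring the padded offsets when they are provided. Both real paths copy the result back to the host and synchronize the stream, which is consistent with the README note that the GPU graph must be disabled when these settings are enabled.

#include <cstdint>
#include <cstddef>
#include <vector>

// Host-side analog of get_runtime_num_segments_kernel: cu_seqlen holds
// max_batch_size + 1 cumulative offsets; only non-empty segments count.
uint32_t count_nonempty_segments(const std::vector<int32_t>& cu_seqlen) {
    uint32_t num_segments = 0;
    for (std::size_t i = 0; i + 1 < cu_seqlen.size(); ++i) {
        if (cu_seqlen[i + 1] - cu_seqlen[i] > 0) {
            ++num_segments;
        }
    }
    return num_segments;
}

// Host-side analog of get_runtime_max_seqlen_kernel: longest segment length,
// measured on cu_seqlen_padded when it is non-null (same b + 1 entries assumed).
uint64_t max_segment_length(const std::vector<int32_t>& cu_seqlen,
                            const int32_t* cu_seqlen_padded) {
    const int32_t* offsets = cu_seqlen_padded ? cu_seqlen_padded : cu_seqlen.data();
    uint64_t max_len = 0;
    for (std::size_t i = 0; i + 1 < cu_seqlen.size(); ++i) {
        int32_t len = offsets[i + 1] - offsets[i];
        if (len > 0 && static_cast<uint64_t>(len) > max_len) {
            max_len = static_cast<uint64_t>(len);
        }
    }
    return max_len;
}

// Example: cu_seqlen = {0, 5, 5, 9, 9} describes segment lengths {5, 0, 4, 0},
// so count_nonempty_segments returns 2 and max_segment_length returns 5.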