diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index bdbc97517..5ccabc97d 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -8,7 +8,6 @@ cmake_minimum_required(VERSION 3.21)
 option(USE_ROCM "Use ROCm" ON)
 option(USE_FUSED_ATTN_AOTRITON "Use aotriton backend" ON)
-option(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS "Build AOTriton GPU kernels" OFF)
 option(USE_FUSED_ATTN_CK "Use ck backend" ON)
 set(USE_CUDA OFF)
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
index 954660e87..031b01a1c 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
@@ -492,6 +492,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
       fused_attn_aotriton_bwd_qkvpacked(
         b, h, max_seqlen, d, attn_scale, dropout,
+        window_size_left, window_size_right,
         qkv_layout, bias_type, attn_mask_type,
         input_QKV, input_O, input_dO,
         output_S,
         output_dQKV,
@@ -678,6 +679,7 @@ void nvte_fused_attn_bwd_kvpacked(
       fused_attn_aotriton_bwd_kvpacked(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
         attn_scale, dropout,
+        window_size_left, window_size_right,
         qkv_layout, bias_type, attn_mask_type,
         input_Q, input_KV, input_O, input_dO,
         output_S,
@@ -858,6 +860,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
       fused_attn_aotriton_bwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk,
         attn_scale, dropout,
+        window_size_left, window_size_right,
         qkv_layout, bias_type, attn_mask_type,
         input_Q, input_K, input_V, input_O, input_dO,
         output_S,
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
index a8a151b40..44d566a8a 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
@@ -296,15 +296,26 @@ void fused_attn_aotriton_fwd_impl(
   NVTE_CHECK_CUDA(attn_fwd(fwd_params, fwd_params.kVersion, stream));
 }
 
+// A thin conversion wrapper around eager tensor-views to lazy tensors
+template <int Rank>
+struct LazyTensorFunctions {
+  static aotriton::TensorView<Rank> acquire(void* cookie) {
+    return *static_cast<aotriton::TensorView<Rank>*>(cookie);
+  }
+  static void dispose(void* cookie) {
+  }
+};
+
 void fused_attn_aotriton_bwd_impl(
   uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d,
   float scaling_factor, float dropout_probability,
-  NVTE_QKV_Layout layout,
+  int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout,
   NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
   void* devPtrQ, void* devPtrK, void* devPtrV,
   void* devPtrO, void* devPtrSoftmaxAux,
   void* devPtrdQ, void* devPtrdK, void* devPtrdV,
   void* devPtrdO,
+  void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
   const uint64_t* devPtrDropoutSeed,
   const uint64_t* devPtrDropoutOffset,
   aotriton::DType dtype,
@@ -312,16 +323,25 @@
   size_t *workspace_size,
   cudaStream_t stream) {
+  const uint64_t dq_acc_size = b*s_q*h*d*sizeof(float);
+
   // Exit to request upper level API to allocate memory if needed
   if(workspace==nullptr){
-    // CK only requires workspace for lse softmax
+    // AOTriton requires workspace for lse softmax
     *workspace_size = b*h*s_q*sizeof(float);
+    // AOTriton requires workspace for DQ_ACC
+    *workspace_size += dq_acc_size;
     return;
   }
 
+  void * delta = workspace;
+  workspace = static_cast<void *>(static_cast<char *>(workspace) + b*h*s_q*sizeof(float));
+  void * dq_acc_ptr = workspace;
+
   std::array<uint64_t, 4> q_stride;
   std::array<uint64_t, 4> k_stride;
   std::array<uint64_t, 4> v_stride;
   std::array<uint64_t, 4> o_stride;
+  std::array<uint64_t, 4> dq_acc_stride;
   generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
                         layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
   generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
@@ -330,6 +350,9 @@ void fused_attn_aotriton_bwd_impl(
                         layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
   generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
                         layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
+  // AOTriton expects a BSHD layout DQ_ACC matrix
+  generateMatrixStrides(b, h, s_q, s_kv, d, dq_acc_stride.data(),
+                        NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Matrix::NVTE_Q_Matrix);
 
   //q and o are having the same shape
   //k and v are having the same shape
@@ -337,7 +360,7 @@
   std::array<uint64_t, 4> q_shape{b, h, s_q, d};
   std::array<uint64_t, 4> kv_shape{b, hg, s_kv, d};
 
-  // m and workspace are of the same shape and stride
+  // m and softmax_lse are of the same shape and stride
   std::array<uint64_t, 2> m_shape{b * h, s_q};
   std::array<uint64_t, 2> m_stride{s_q, 1};
 
@@ -355,13 +378,45 @@
 
   // auxilary tensors
   auto M_tensor = aotriton::TensorView<2>(reinterpret_cast<intptr_t>(devPtrSoftmaxAux), m_shape, m_stride, aotriton::DType::kFloat32);
-  auto wkspace_tensor = aotriton::TensorView<2>(reinterpret_cast<intptr_t>(workspace), m_shape, m_stride, aotriton::DType::kFloat32);
+  auto delta_tensor = aotriton::TensorView<2>(reinterpret_cast<intptr_t>(delta), m_shape, m_stride, aotriton::DType::kFloat32);
+  auto dq_acc_tensor = aotriton::TensorView<4>(reinterpret_cast<intptr_t>(dq_acc_ptr), q_shape, dq_acc_stride, aotriton::DType::kFloat32);
+  NVTE_CHECK_CUDA(hipMemsetAsync(dq_acc_ptr, 0, dq_acc_size, stream));
+
+  auto dq_acc_lazy = aotriton::LazyTensor<4> {
+    .cookie = &dq_acc_tensor,
+    .acquire = &LazyTensorFunctions<4>::acquire,
+    .dispose = &LazyTensorFunctions<4>::dispose
+  };
+  auto delta_lazy = aotriton::LazyTensor<2> {
+    .cookie = &delta_tensor,
+    .acquire = &LazyTensorFunctions<2>::acquire,
+    .dispose = &LazyTensorFunctions<2>::dispose
+  };
+
+  // Cumulative seqlen tensors
+  std::array<uint64_t, 1> cu_seqlens_shape{b+1};
+  std::array<uint64_t, 1> cu_seqlens_stride{1};
+  auto cu_seqlens_q = aotriton::TensorView<1>(reinterpret_cast<intptr_t>(devPtrCuSeqlensQ), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32);
+  auto cu_seqlens_k = aotriton::TensorView<1>(reinterpret_cast<intptr_t>(devPtrCuSeqlensKV), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32);
 
   bool nvte_log_aotriton_config = false;
   if (const char* env_p = std::getenv("NVTE_LOG_AOTRITON_CONFIG") ) {
     if (env_p != nullptr && std::string(env_p) == "1") nvte_log_aotriton_config = true;
   }
+  aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
+  using aotriton::v2::flash::attn_bwd;
+  auto seed = mk_aoscalartensor(devPtrDropoutSeed);
+  auto offset = mk_aoscalartensor(devPtrDropoutOffset);
+  const auto is_causal = mask_type == NVTE_CAUSAL_MASK;
+
+  using aotriton::v3::flash::VarlenType;
+  int8_t varlen_type = VarlenType::None;
+
+  auto [window_left, window_right] = get_window_sizes(window_size_left, window_size_right, is_causal);
+  using aotriton::v3::flash::CausalType;
+  int8_t causal_type = is_causal ?
+    CausalType::WindowedAttention : CausalType::None;
+
+  if (nvte_log_aotriton_config) { std::cout<
-  aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
-  using aotriton::v2::flash::attn_bwd;
-  auto seed = mk_aoscalartensor(devPtrDropoutSeed);
-  auto offset = mk_aoscalartensor(devPtrDropoutOffset);
-  const auto is_causal = mask_type == NVTE_CAUSAL_MASK;
-  NVTE_CHECK_CUDA(attn_bwd(q_tensor,
-                           k_tensor,
-                           v_tensor,
-                           empty_bias,
-                           scaling_factor,
-                           o_tensor,
-                           do_tensor,
-                           dq_tensor,
-                           dk_tensor,
-                           dv_tensor,
-                           empty_bias,
-                           M_tensor,
-                           wkspace_tensor,
-                           dropout_probability,
-                           seed,
-                           offset,
-                           0,
-                           is_causal,
-                           stream));
+  aotriton::v3::flash::attn_bwd_params bwd_params{};
+  bwd_params.Q = q_tensor;
+  bwd_params.K = k_tensor;
+  bwd_params.V = v_tensor;
+  bwd_params.B = empty_bias;
+  bwd_params.Sm_scale = scaling_factor;
+  bwd_params.Out = o_tensor;
+  if(varlen_type){
+    bwd_params.cu_seqlens_q = cu_seqlens_q;
+    bwd_params.cu_seqlens_k = cu_seqlens_k;
+    bwd_params.Max_seqlen_q = s_q;
+    bwd_params.Max_seqlen_k = s_kv;
+  }
+  bwd_params.DO = do_tensor;
+  bwd_params.DK = dk_tensor;
+  bwd_params.DV = dv_tensor;
+  bwd_params.DQ = dq_tensor;
+  bwd_params.DB = empty_bias;
+  bwd_params.L = M_tensor;
+  bwd_params.D = delta_lazy;
+  bwd_params.dropout_p = dropout_probability;
+  bwd_params.philox_seed_ptr = seed;
+  bwd_params.philox_offset1 = offset;
+  bwd_params.philox_offset2 = 0;
+  bwd_params.causal_type = causal_type;
+  bwd_params.varlen_type = varlen_type;
+  bwd_params.window_left = window_left;
+  bwd_params.window_right = window_right;
+  bwd_params.DQ_ACC = dq_acc_lazy;
+
+  NVTE_CHECK_CUDA(attn_bwd(bwd_params, bwd_params.kVersion, stream));
 }
 #endif // USE_FUSED_ATTN_AOTRITON
 } // namespace fused_attn_rocm
@@ -495,6 +582,7 @@ void fused_attn_aotriton_fwd_qkvpacked(
 void fused_attn_aotriton_bwd_qkvpacked(
   size_t b, size_t h, size_t max_seqlen, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
@@ -532,12 +620,14 @@
 
   fused_attn_aotriton_bwd_impl(
     b, h, h, max_seqlen, max_seqlen, d,
     attn_scale, dropout,
+    window_size_left, window_size_right,
     qkv_layout, bias_type, attn_mask_type,
     devPtrQ, devPtrK, devPtrV,
     devPtrO, devPtrSoftmaxStats,
     devPtrdQ, devPtrdK, devPtrdV,
     devPtrdO,
+    input_cu_seqlens->data.dptr, input_cu_seqlens->data.dptr,
     reinterpret_cast<uint64_t*>(rng_state->data.dptr),
     reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1,
     nvte_to_aotriton_dtype(QKV_type),
@@ -652,6 +742,7 @@ void fused_attn_aotriton_fwd_kvpacked(
 void fused_attn_aotriton_bwd_kvpacked(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
@@ -692,12 +783,14 @@
 
   fused_attn_aotriton_bwd_impl(
     b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
     attn_scale, dropout,
+    window_size_left, window_size_right,
     qkv_layout, bias_type, attn_mask_type,
     devPtrQ, devPtrK, devPtrV,
     devPtrO, devPtrSoftmaxStats,
     devPtrdQ, devPtrdK, devPtrdV,
     devPtrdO,
+    input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
     reinterpret_cast<uint64_t*>(rng_state->data.dptr),
     reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1,
     nvte_to_aotriton_dtype(QKV_type),
@@ -803,6 +896,7 @@ void fused_attn_aotriton_fwd(
 void fused_attn_aotriton_bwd(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
@@ -831,12 +925,14 @@
 
   fused_attn_aotriton_bwd_impl(
     b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
     attn_scale, dropout,
+    window_size_left, window_size_right,
     qkv_layout, bias_type, attn_mask_type,
     devPtrQ, devPtrK, devPtrV,
     devPtrO, devPtrSoftmaxStats,
     devPtrdQ, devPtrdK, devPtrdV,
     devPtrdO,
+    input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
     reinterpret_cast<uint64_t*>(rng_state->data.dptr),
     reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1,
     nvte_to_aotriton_dtype(QKV_type),
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h
index b016acc67..3fdb359d1 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h
@@ -47,6 +47,7 @@ void fused_attn_aotriton_fwd_qkvpacked(
 void fused_attn_aotriton_bwd_qkvpacked(
   size_t b, size_t h, size_t max_seqlen, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
@@ -72,6 +73,7 @@ void fused_attn_aotriton_fwd_kvpacked(
 void fused_attn_aotriton_bwd_kvpacked(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
@@ -98,6 +100,7 @@ void fused_attn_aotriton_fwd(
 void fused_attn_aotriton_bwd(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
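For reference, the LazyTensorFunctions wrapper introduced in fused_attn_aotriton.cpp follows a plain C-style callback pattern: the lazy handle carries an opaque cookie plus acquire/dispose function pointers, and acquire simply casts the cookie back to the eagerly-built tensor view. The sketch below mirrors that pattern with self-contained stand-in types; StandInView, LazyHandle and LazyFunctions are illustrative placeholders only, not the aotriton::TensorView / aotriton::LazyTensor API itself.

// Illustrative sketch of the cookie/acquire/dispose pattern used by the diff.
// StandInView and LazyHandle are hypothetical stand-ins, not aotriton types.
#include <cstdint>
#include <iostream>

template <int Rank>
struct StandInView {            // plays the role of an eager TensorView<Rank>
  std::uintptr_t base = 0;
};

template <int Rank>
struct LazyHandle {             // plays the role of a LazyTensor<Rank>
  void* cookie;                                  // opaque pointer to the eager view
  StandInView<Rank> (*acquire)(void* cookie);    // materialize the view on demand
  void (*dispose)(void* cookie);                 // release anything acquire allocated
};

// Same shape as the diff's LazyTensorFunctions: acquire casts the cookie back
// to the concrete view type; dispose is a no-op because nothing is owned.
template <int Rank>
struct LazyFunctions {
  static StandInView<Rank> acquire(void* cookie) {
    return *static_cast<StandInView<Rank>*>(cookie);
  }
  static void dispose(void*) {}
};

int main() {
  StandInView<4> dq_acc{0x1000};   // eagerly constructed view
  LazyHandle<4> lazy{&dq_acc, &LazyFunctions<4>::acquire, &LazyFunctions<4>::dispose};

  // A consumer (the kernel launcher in the real code) resolves the view lazily.
  StandInView<4> resolved = lazy.acquire(lazy.cookie);
  std::cout << std::hex << resolved.base << "\n";  // prints 1000
  lazy.dispose(lazy.cookie);
}

Because dispose owns nothing in this pattern, the eager view only has to outlive the lazy handle; the patch satisfies this by keeping dq_acc_tensor and delta_tensor as locals that remain alive until attn_bwd returns.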