diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 03068ab..040eba1 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -190,6 +190,7 @@ void set_params_dgrad( const float softcap, bool has_mask, bool has_bias, + bool accum_dbias, bool deterministic, const bool unpadded_lse ) { @@ -245,6 +246,8 @@ void set_params_dgrad( // Softmax sum params.dsoftmax_sum = dsoftmax_sum_d; + params.accum_dbias = accum_dbias; + params.deterministic = deterministic; } @@ -977,12 +980,13 @@ mha_bwd( ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts) : dv; dbias_expanded = has_bias - ? ( - (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q) // MQA / GQA or dbias has different batch size or seqlen_q - ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts) - : dbias - ) + ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1) // MQA / GQA or dbias has different batch size or seqlen_q + ? (seqlen_q_bias == 1) + ? 
torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat)) + : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts) + : dbias : torch::empty({0}, opts); + bool accum_dbias = has_bias && (seqlen_q_bias == 1 && seqlen_q != 1); Flash_bwd_params params; @@ -1009,6 +1013,7 @@ mha_bwd( softcap, has_mask, has_bias, + accum_dbias, deterministic, /*unpadded_lse*/false ); @@ -1036,9 +1041,10 @@ mha_bwd( if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) { at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2}); } else { - dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2}); - if (seqlen_q_bias == 1) { - dbias_expanded = at::sum(dbias_expanded, {2}, true); + if (accum_dbias) { + dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2}); + } else { + dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2}); } if (batch_size_bias == 1) { dbias_expanded = at::sum(dbias_expanded, {0}, true); @@ -1238,6 +1244,7 @@ mha_varlen_bwd( softcap, has_mask, has_bias, + /*accum_dbias*/false, deterministic, /*unpadded_lse*/true ); diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index a1c9bf1..29c342f 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -195,6 +195,8 @@ struct Flash_bwd_params : public Flash_fwd_params { bool deterministic; index_t dq_accum_split_stride; + + bool accum_dbias; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index e8578e2..7c4056e 
100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -159,6 +159,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in Shape<Int<kBlockM>, Int<kBlockN>>{}, make_stride(params.dbias_row_stride, _1{}) ); + [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr; + if constexpr (Has_bias) { + gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias; + } Tensor gdO = make_tensor( make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do), Shape<Int<kBlockM>, Int<kHeadDim>>{}, @@ -848,12 +852,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in __syncthreads(); if constexpr (Has_bias) { // Write dS to dBias - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_dBias, - tBiassBias, tdBiasgdBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); + if (!params.accum_dbias) { + FLASH_NAMESPACE::copy_MN( + gmem_tiled_copy_dBias, + tBiassBias, tdBiasgdBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + } else { + #pragma unroll + for (int m = 0; m < size<1>(tBiassBias); ++m) { + if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) { + #pragma unroll + for (int n = 0; n < size<2>(tBiassBias); ++n) { + if (Is_even_MN || tBiaspBias(n)) { + #pragma unroll + for (int i = 0; i < size<0>(tBiassBias); ++i) { + const auto coord = tBiascBias(i, m, n); + const int row = get<0>(coord); + const int col = get<1>(coord); + if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) { + atomicAdd( + gdBias_accum_ptr + row * params.dbias_row_stride + col, + static_cast<ElementAccum>(tBiassBias(i, m, n)) + ); + } + } + } + } + } + } + } } // if (cute::thread0()) { print(tPrP); } @@ -994,6 +1023,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // Advance gBias and gdBias tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); tdBiasgdBias.data() = tdBiasgdBias.data() +
(-int(kBlockM * params.dbias_row_stride)); + if (params.accum_dbias) { + gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride); + } if (any_active_next) { FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias,