Commit e5b045d

Add atomic bias-grad accumulation option
Adds an optional accumulation path for bias gradients using atomic updates when accumulation is enabled, avoiding overwrites when multiple tiles contribute. Keeps the existing fast write path when accumulation is disabled, respects sequence bounds, and correctly tracks the accumulation pointer across tile steps. Improves correctness for split/streamed backward passes where bias gradients must be aggregated across blocks.
1 parent 2e69c3d · commit e5b045d
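For context, here is a minimal, self-contained CUDA sketch (not part of this commit) of the idea behind the new path: when several tile blocks contribute partial bias gradients to the same rows, plain stores from different blocks overwrite each other, while atomicAdd aggregates them. Every name below (toy_accumulate_dbias, partial, n_blocks, the accumulate flag) is hypothetical and only mirrors the shape of the problem, not the kernel code in the diff.

#include <cstdio>
#include <cuda_runtime.h>

// Each block owns one partial tile of bias gradients; all tiles map onto the
// same rows of dbias_accum, so their contributions must be summed, not overwritten.
__global__ void toy_accumulate_dbias(float *dbias_accum, const float *partial,
                                     int rows, int cols, bool accumulate) {
    const float *tile = partial + blockIdx.x * rows * cols;
    for (int idx = threadIdx.x; idx < rows * cols; idx += blockDim.x) {
        if (accumulate) {
            atomicAdd(&dbias_accum[idx], tile[idx]);   // race-free aggregation across blocks
        } else {
            dbias_accum[idx] = tile[idx];              // only safe if exactly one block writes each element
        }
    }
}

int main() {
    const int rows = 4, cols = 8, n_blocks = 3, n = rows * cols;
    float *dbias, *partial;
    cudaMallocManaged(&dbias, n * sizeof(float));
    cudaMallocManaged(&partial, n_blocks * n * sizeof(float));
    for (int i = 0; i < n_blocks * n; ++i) partial[i] = 1.0f;  // each tile contributes 1.0
    cudaMemset(dbias, 0, n * sizeof(float));

    toy_accumulate_dbias<<<n_blocks, 64>>>(dbias, partial, rows, cols, /*accumulate=*/true);
    cudaDeviceSynchronize();
    printf("dbias[0] = %.1f (expect %d.0 with accumulation)\n", dbias[0], n_blocks);

    cudaFree(dbias);
    cudaFree(partial);
    return 0;
}

In the diff itself, the accumulation path writes through an ElementAccum pointer (gdBias_accum_ptr) rather than the Element-typed copy used by the fast path, so per-tile contributions are summed in the accumulator type.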

1 file changed: +38 −6 lines changed
csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 38 additions & 6 deletions
@@ -159,6 +159,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.dbias_row_stride, _1{})
     );
+    [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr;
+    if constexpr (Has_bias) {
+        gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias;
+    }
     Tensor gdO = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
         Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -848,12 +852,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         __syncthreads();
         if constexpr (Has_bias) {
             // Write dS to dBias
-            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
-                gmem_tiled_copy_dBias,
-                tBiassBias, tdBiasgdBias,
-                tBiascBias, tBiaspBias,
-                binfo.actual_seqlen_q - m_block * kBlockM
-            );
+            if (!params.accum_dbias) {
+                FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
+                    gmem_tiled_copy_dBias,
+                    tBiassBias, tdBiasgdBias,
+                    tBiascBias, tBiaspBias,
+                    binfo.actual_seqlen_q - m_block * kBlockM
+                );
+            } else {
+                #pragma unroll
+                for (int m = 0; m < size<1>(tBiassBias); ++m) {
+                    if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
+                        #pragma unroll
+                        for (int n = 0; n < size<2>(tBiassBias); ++n) {
+                            if (Is_even_MN || tBiaspBias(n)) {
+                                #pragma unroll
+                                for (int i = 0; i < size<0>(tBiassBias); ++i) {
+                                    const auto coord = tBiascBias(i, m, n);
+                                    const int row = get<0>(coord);
+                                    const int col = get<1>(coord);
+                                    if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
+                                        atomicAdd(
+                                            gdBias_accum_ptr + row * params.dbias_row_stride + col,
+                                            static_cast<ElementAccum>(tBiassBias(i, m, n))
+                                        );
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
         }
 
         // if (cute::thread0()) { print(tPrP); }
@@ -994,6 +1023,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Advance gBias and gdBias
         tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
         tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
+        if (params.accum_dbias) {
+            gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride);
+        }
         if (any_active_next) {
            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                gmem_tiled_copy_Bias,
