Commit b93a5c5

Optimizes bias grad: remove atomics, add block reduce
Replaces atomic updates in the bias-gradient path with on-chip accumulation and a post-loop row-sum reduction, reducing contention and improving performance and determinism. Derives the accumulation condition from layout (row stride == 0) and sequence length, drops auxiliary pointers/increments, and adds necessary synchronization to avoid shared-memory races when reusing buffers. Cleans up zeroing and copy ordering and consolidates the final write to global memory.
1 parent 8b72ed7 commit b93a5c5
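As a rough illustration of the optimization described above, here is a minimal standalone CUDA sketch; it is not the flash_dmattn kernel. The tile sizes, the flat dS tile layout, and the names dbias_atomic / dbias_block_reduce are assumptions made for the example. The first kernel shows the old pattern of folding each dS element into global memory with atomicAdd; the second stages the tile on chip and performs a post-loop column-sum reduction so each column is written once, without atomics and in a fixed (deterministic) order.

// Minimal sketch, not the flash_dmattn kernel: BLOCK_M/BLOCK_N, the flat dS
// tile layout, and the kernel names are illustrative assumptions.
#include <cuda_runtime.h>

constexpr int BLOCK_M = 32;  // query rows per tile (assumed)
constexpr int BLOCK_N = 32;  // key columns per tile (assumed)

// Old pattern: every thread folds its dS element into global dbias with an
// atomic, so threads contend on the same addresses and the summation order
// (and thus the rounded result) varies from run to run.
__global__ void dbias_atomic(const float* dS_tile, float* dbias) {
    int row = threadIdx.y;
    int col = threadIdx.x;
    atomicAdd(&dbias[col], dS_tile[row * BLOCK_N + col]);
}

// New pattern: stage the tile in shared memory, reduce each column inside the
// block, then issue one ordinary (non-atomic) write per column. The fixed
// reduction order makes the result deterministic.
__global__ void dbias_block_reduce(const float* dS_tile, float* dbias) {
    __shared__ float smem[BLOCK_M][BLOCK_N];
    int row = threadIdx.y;
    int col = threadIdx.x;
    smem[row][col] = dS_tile[row * BLOCK_N + col];
    __syncthreads();                       // all partials staged before summing
    if (row == 0) {
        float colsum = 0.f;
        #pragma unroll
        for (int r = 0; r < BLOCK_M; ++r) { colsum += smem[r][col]; }
        dbias[col] = colsum;               // single consolidated write, no atomics
    }
}

In the sketch each key-column tile is owned by a single thread block (launched as one dim3(BLOCK_N, BLOCK_M) block per tile), which is what makes the final plain store safe; the real kernel likewise accumulates per-column partials in registers across query blocks and reduces them once in the epilogue.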

File tree

1 file changed (+57 / -35 lines)

csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 57 additions & 35 deletions
@@ -101,6 +101,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
+    bool accum_dbias = Has_bias && (params.dbias_row_stride == 0) && (binfo.actual_seqlen_q > 1);

     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
@@ -159,10 +160,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.dbias_row_stride, _1{})
     );
-    [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr;
-    if constexpr (Has_bias) {
-        gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias;
-    }
     Tensor gdO = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
         Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -287,8 +284,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     GmemTiledCopydO gmem_tiled_copy_dO;
     auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
     typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
-    typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
-    auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);
     auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
     using GmemLayoutAtomdQaccum = std::conditional_t<
         !Seq_parallel,
@@ -297,6 +292,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     >;
     GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
     auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
+    auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);

     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -346,6 +343,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

     Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (MMA, MMA_N, MMA_K)
     Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (MMA, MMA_N, MMA_K)
+    [[maybe_unused]] auto acc_dbias = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});
+    [[maybe_unused]] auto acc_dbias_rowcol = make_tensor(acc_dbias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dbias.layout()));

     // Copy Atom retiling
     auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
@@ -641,8 +640,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }

-    clear(acc_dv);
     clear(acc_dk);
+    clear(acc_dv);
+    if constexpr (Has_bias) { if (accum_dbias) { clear(acc_dbias); } }

     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
@@ -806,6 +806,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
                 if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                 dS(mi, ni) = scaled_ds;
+                if constexpr (Has_bias) {
+                    if (accum_dbias) {
+                        acc_dbias_rowcol(mi, ni) += scaled_ds;
+                    }
+                }
             }
         }
         // if (cute::thread0()) { print(dS); }
@@ -852,36 +857,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         __syncthreads();
         if constexpr (Has_bias) {
             // Write dS to dBias
-            if (!params.accum_dbias) {
+            if (!accum_dbias) {
                 FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
                     gmem_tiled_copy_dBias,
                     tBiassBias, tdBiasgdBias,
                     tBiascBias, tBiaspBias,
                     binfo.actual_seqlen_q - m_block * kBlockM
                 );
-            } else {
-                #pragma unroll
-                for (int m = 0; m < size<1>(tBiassBias); ++m) {
-                    if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
-                        #pragma unroll
-                        for (int n = 0; n < size<2>(tBiassBias); ++n) {
-                            if (Is_even_MN || tBiaspBias(n)) {
-                                #pragma unroll
-                                for (int i = 0; i < size<0>(tBiassBias); ++i) {
-                                    const auto coord = tBiascBias(i, m, n);
-                                    const int row = get<0>(coord);
-                                    const int col = get<1>(coord);
-                                    if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
-                                        atomicAdd(
-                                            gdBias_accum_ptr + row * params.dbias_row_stride + col,
-                                            static_cast<ElementAccum>(tBiassBias(i, m, n))
-                                        );
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
             }
         }

@@ -1023,9 +1005,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Advance gBias and gdBias
         tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
         tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
-        if (params.accum_dbias) {
-            gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride);
-        }
         if (any_active_next) {
             FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_Bias,
@@ -1069,10 +1048,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

     // Epilogue

+    if constexpr (Has_bias) {
+        if (accum_dbias) {
+            const int actual_block_n = Is_even_MN ? kBlockN : std::max(0, std::min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN));
+
+            // Convert acc_dbias from fp32 to fp16
+            Tensor tdBiasrdBias = FLASH_NAMESPACE::convert_type<Element>(acc_dbias);
+
+            // Partition sBias to match the accumulator partitioning
+            Tensor tdBiasadBias = smem_thr_copy_Bias.retile_S(tdBiasrdBias); // ((Atom, AtomNum), MMA_M, MMA_N)
+
+            // We need syncthreads here since we're writing to the same location as sBias.
+            // Without syncthreads, some thread might modify the location of sBias while another thread
+            // is reading it for dQ gemm, leading to a race condition.
+            // If Is_last, there's already a __syncthreads() at the end of the loop.
+            if (!Is_last) { __syncthreads(); }
+
+            cute::copy(smem_tiled_copy_PdS, tdBiasadBias, tdSsdS);
+
+            __syncthreads();
+            for (int col = threadIdx.x; col < kBlockN; col += blockDim.x) {
+                if (col < actual_block_n) {
+                    ElementAccum rowsum = 0.f;
+                    #pragma unroll
+                    for (int row = 0; row < kBlockM; ++row) {
+                        rowsum += static_cast<ElementAccum>(sdS(row, col));
+                    }
+                    sdS(0, col) = static_cast<Element>(rowsum);
+                }
+            }
+            __syncthreads();
+
+            #pragma unroll
+            for (int ni = 0; ni < size(tBiaspBias); ++ni) { tBiaspBias(ni) = ni < actual_block_n; }
+            // Clear_OOB_K must be false since we don't want to write zeros to gmem
+            FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/false, /*Clear_OOB_MN=*/false>(
+                gmem_tiled_copy_dBias,
+                tBiassBias, tdBiasgdBias,
+                tBiascBias, tBiaspBias,
+                /*max_M=*/1
+            );
+        }
+    }
+
     #pragma unroll
     for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax; }

-    // Convert acc_dv from fp32 to fp16
+    // Convert acc_dk, acc_dv from fp32 to fp16
     Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
     Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);
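One detail worth calling out from the first hunk: the accumulation mode is no longer a flag on the params struct but is derived from the dbias layout. Below is a hedged host-side sketch of that derivation; ParamsSketch and needs_dbias_accumulation are illustrative names, and only the two fields mirror the diff. A dbias row stride of zero means every query row of the tile aliases the same global bias row (the bias is broadcast over the query dimension), so its gradient has to be summed over that dimension; with actual_seqlen_q == 1 there is only one row and the plain per-row copy path already suffices.

// Illustrative sketch only; the real kernel reads these fields from its Params.
struct ParamsSketch {
    long dbias_row_stride;  // 0 => dbias has a single row broadcast over queries
    int  actual_seqlen_q;
};

// Mirrors: accum_dbias = Has_bias && (params.dbias_row_stride == 0) && (binfo.actual_seqlen_q > 1)
inline bool needs_dbias_accumulation(bool has_bias, const ParamsSketch &p) {
    return has_bias && p.dbias_row_stride == 0 && p.actual_seqlen_q > 1;
}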
