
Commit 78bb93d

Merge pull request #193 from SmallDoges/fix-189
Enhance bias gradient accumulation in backward pass
2 parents 7c4b102 + 7f15118 commit 78bb93d

3 files changed: +55 -14 lines changed


csrc/flash_dmattn/flash_api.cpp

Lines changed: 15 additions & 8 deletions
@@ -190,6 +190,7 @@ void set_params_dgrad(
     const float softcap,
     bool has_mask,
     bool has_bias,
+    bool accum_dbias,
     bool deterministic,
     const bool unpadded_lse
 ) {
@@ -245,6 +246,8 @@ void set_params_dgrad(
     // Softmax sum
     params.dsoftmax_sum = dsoftmax_sum_d;
 
+    params.accum_dbias = accum_dbias;
+
     params.deterministic = deterministic;
 }
 
@@ -977,12 +980,13 @@ mha_bwd(
         ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts)
         : dv;
     dbias_expanded = has_bias
-        ? (
-            (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)    // MQA / GQA or dbias has different batch size or seqlen_q
-                ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
-                : dbias
-        )
+        ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1)      // MQA / GQA or dbias has different batch size or seqlen_q
+            ? (seqlen_q_bias == 1)
+                ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
+                : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
+            : dbias
         : torch::empty({0}, opts);
+    bool accum_dbias = has_bias && (seqlen_q_bias == 1 && seqlen_q != 1);
 
     Flash_bwd_params params;
 
@@ -1009,6 +1013,7 @@ mha_bwd(
         softcap,
         has_mask,
         has_bias,
+        accum_dbias,
         deterministic,
         /*unpadded_lse*/false
     );
@@ -1036,9 +1041,10 @@ mha_bwd(
         if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
             at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
         } else {
-            dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
-            if (seqlen_q_bias == 1) {
-                dbias_expanded = at::sum(dbias_expanded, {2}, true);
+            if (accum_dbias) {
+                dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2});
+            } else {
+                dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
             }
             if (batch_size_bias == 1) {
                 dbias_expanded = at::sum(dbias_expanded, {0}, true);
@@ -1238,6 +1244,7 @@ mha_varlen_bwd(
        softcap,
        has_mask,
        has_bias,
+       /*accum_dbias*/false,
        deterministic,
        /*unpadded_lse*/true
    );
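
For reference, a minimal host-side sketch of what the new accum_dbias path computes; the helper name and shapes below are illustrative, not part of the patch. When the bias was broadcast over seqlen_q (seqlen_q_bias == 1), its gradient is the per-row gradient summed over the query dimension, accumulated in float32 just like the opts.dtype(at::kFloat) buffer allocated above:

#include <ATen/ATen.h>

// Hypothetical helper: dbias_full stands for the dense [B, H, Sq, Sk_rounded]
// gradient the old path materialized; the broadcast bias gradient is its sum
// over the query axis, kept in float32 for accumulation.
at::Tensor reduce_broadcast_dbias(const at::Tensor &dbias_full) {
    return at::sum(dbias_full.to(at::kFloat), {2}, /*keepdim=*/true);  // [B, H, 1, Sk_rounded]
}

Accumulating directly into the [batch_size, num_heads, 1, seqlen_k_rounded] buffer lets the kernel skip materializing the full seqlen_q-sized gradient only to reduce it afterwards.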

csrc/flash_dmattn/src/flash.h

Lines changed: 2 additions & 0 deletions
@@ -195,6 +195,8 @@ struct Flash_bwd_params : public Flash_fwd_params {
 
     bool deterministic;
     index_t dq_accum_split_stride;
+
+    bool accum_dbias;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 38 additions & 6 deletions
@@ -159,6 +159,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.dbias_row_stride, _1{})
     );
+    [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr;
+    if constexpr (Has_bias) {
+        gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias;
+    }
     Tensor gdO = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
         Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -848,12 +852,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     __syncthreads();
     if constexpr (Has_bias) {
         // Write dS to dBias
-        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
-            gmem_tiled_copy_dBias,
-            tBiassBias, tdBiasgdBias,
-            tBiascBias, tBiaspBias,
-            binfo.actual_seqlen_q - m_block * kBlockM
-        );
+        if (!params.accum_dbias) {
+            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
+                gmem_tiled_copy_dBias,
+                tBiassBias, tdBiasgdBias,
+                tBiascBias, tBiaspBias,
+                binfo.actual_seqlen_q - m_block * kBlockM
+            );
+        } else {
+            #pragma unroll
+            for (int m = 0; m < size<1>(tBiassBias); ++m) {
+                if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
+                    #pragma unroll
+                    for (int n = 0; n < size<2>(tBiassBias); ++n) {
+                        if (Is_even_MN || tBiaspBias(n)) {
+                            #pragma unroll
+                            for (int i = 0; i < size<0>(tBiassBias); ++i) {
+                                const auto coord = tBiascBias(i, m, n);
+                                const int row = get<0>(coord);
+                                const int col = get<1>(coord);
+                                if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
+                                    atomicAdd(
+                                        gdBias_accum_ptr + row * params.dbias_row_stride + col,
+                                        static_cast<ElementAccum>(tBiassBias(i, m, n))
+                                    );
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
     }
 
     // if (cute::thread0()) { print(tPrP); }
@@ -994,6 +1023,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Advance gBias and gdBias
         tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
         tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
+        if (params.accum_dbias) {
+            gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride);
+        }
         if (any_active_next) {
             FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_Bias,
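
The atomicAdd loop above is the device-side counterpart of the host reduction: with seqlen_q_bias == 1, every query row maps onto the same dbias row, so each contribution is folded in with an fp32 atomic instead of a tiled copy. A standalone sketch of that accumulation pattern, with illustrative names and a simplified layout (this is not the flash-dmattn kernel itself):

#include <cuda_runtime.h>

// Illustrative only: dS holds per-query-row bias gradients for one (batch, head)
// pair; dbias_accum is the single broadcast row they all collapse onto.
// dbias_accum must be zero-initialized before launch.
__global__ void accum_dbias_rows(const float *dS, float *dbias_accum, int rows, int cols) {
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= cols) return;
    float partial = 0.f;
    // Each block in the y-dimension reduces a strided subset of rows locally...
    for (int row = blockIdx.y; row < rows; row += gridDim.y) {
        partial += dS[row * cols + col];
    }
    // ...then one fp32 atomic per column merges the partial sums across blocks,
    // mirroring the atomicAdd into gdBias_accum_ptr in the patch.
    atomicAdd(&dbias_accum[col], partial);
}

Accumulating in float (ElementAccum in the patch) rather than the half-precision element type keeps the many small cross-block contributions from losing precision, which is why the host side allocates the broadcast buffer with opts.dtype(at::kFloat).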
