Commit 9181f70

Fixes dbias accumulation for broadcasted bias
Improves backward handling when bias is broadcast across sequence or batch by allocating correctly shaped scratch buffers and adjusting reduction paths. Adds a kernel parameter to accumulate along sequence for S=1 bias, and uses fp32 buffers for numerically stable accumulation. Corrects the previous over-eager scratch allocation on batch-size mismatch to only trigger for shared (B=1) or head-grouped cases, aligning with broadcasting semantics (incl. MQA/GQA). Leaves the variable-length path unchanged (no accumulation). Results in correct dbias reductions and gradients for broadcasted bias with better numerical stability.
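
For reference, a minimal ATen sketch of the reduction semantics this message describes (illustrative only, not code from the extension; the helper name and the generic B/H/Sq/Sk shapes are assumptions, while dbias_expanded, num_heads_bias, batch_size_bias, and seqlen_q_bias mirror the diff below). The gradient of a bias that was broadcast over a dimension is the sum of the expanded gradient over that dimension:

#include <torch/torch.h>

// Hypothetical standalone helper; mha_bwd below performs the same reductions,
// routing the seqlen_q_bias == 1 case through the new accum_dbias kernel flag.
at::Tensor reduce_dbias_for_broadcast(
    const at::Tensor& dbias_expanded,  // {B, H, Sq, Sk} scratch accumulator
    int64_t num_heads_bias,            // divides H; smaller than H for MQA/GQA
    int64_t batch_size_bias,           // 1 when the bias is shared across the batch
    int64_t seqlen_q_bias              // 1 when the bias is broadcast along seqlen_q
) {
    const auto B  = dbias_expanded.size(0);
    const auto H  = dbias_expanded.size(1);
    const auto Sq = dbias_expanded.size(2);
    const auto Sk = dbias_expanded.size(3);

    // Collapse grouped query heads onto their shared bias head (MQA/GQA).
    auto d = at::sum(
        at::reshape(dbias_expanded, {B, num_heads_bias, H / num_heads_bias, Sq, Sk}), {2});

    if (seqlen_q_bias == 1) {
        d = at::sum(d, {2}, /*keepdim=*/true);  // accumulate along the query dimension
    }
    if (batch_size_bias == 1) {
        d = at::sum(d, {0}, /*keepdim=*/true);  // accumulate across the batch
    }
    return d;  // broadcast shape {B_b, H_b, Sq_b, Sk}
}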
1 parent e5b045d commit 9181f70

1 file changed: csrc/flash_dmattn/flash_api.cpp (15 additions, 8 deletions)
@@ -190,6 +190,7 @@ void set_params_dgrad(
     const float softcap,
     bool has_mask,
     bool has_bias,
+    bool accum_dbias,
     bool deterministic,
     const bool unpadded_lse
 ) {
@@ -245,6 +246,8 @@ void set_params_dgrad(
     // Softmax sum
     params.dsoftmax_sum = dsoftmax_sum_d;
 
+    params.accum_dbias = accum_dbias;
+
     params.deterministic = deterministic;
 }
 
@@ -977,12 +980,13 @@ mha_bwd(
         ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts)
         : dv;
     dbias_expanded = has_bias
-        ? (
-            (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)    // MQA / GQA or dbias has different batch size or seqlen_q
-            ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
-            : dbias
-        )
+        ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1)    // MQA / GQA or dbias has different batch size or seqlen_q
+            ? (seqlen_q_bias == 1)
+                ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
+                : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
+            : dbias
         : torch::empty({0}, opts);
+    bool accum_dbias = has_bias && seqlen_q_bias != seqlen_q && seqlen_q_bias == 1;
 
     Flash_bwd_params params;
 
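
As an aside on the opts.dtype(at::kFloat) scratch introduced above: a small, self-contained illustration (not from the repository) of why a buffer that accumulates every query row into a single bias row should be fp32. A half-precision running sum stops absorbing increments once they fall below half a ulp of the running value, while an fp32 accumulator keeps them:

#include <ATen/ATen.h>
#include <iostream>

int main() {
    // Accumulate 100k small contributions into one cell, as the seqlen_q_bias == 1 path does.
    float acc_fp32 = 0.f;
    at::Half acc_fp16(0.f);
    for (int i = 0; i < 100000; ++i) {
        acc_fp32 += 1e-3f;
        acc_fp16 = at::Half(static_cast<float>(acc_fp16) + 1e-3f);
    }
    std::cout << "fp32: " << acc_fp32                        // ~100
              << "  fp16: " << static_cast<float>(acc_fp16)  // stalls far below the true sum
              << std::endl;
}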
@@ -1009,6 +1013,7 @@ mha_bwd(
         softcap,
         has_mask,
         has_bias,
+        accum_dbias,
         deterministic,
         /*unpadded_lse*/false
     );
@@ -1036,9 +1041,10 @@
     if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
         at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
     } else {
-        dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
-        if (seqlen_q_bias == 1) {
-            dbias_expanded = at::sum(dbias_expanded, {2}, true);
+        if (accum_dbias) {
+            dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2});
+        } else {
+            dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
         }
         if (batch_size_bias == 1) {
             dbias_expanded = at::sum(dbias_expanded, {0}, true);
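
For the else branch above, a quick autograd cross-check (illustrative, standalone; the sizes are made up) of the broadcasting semantics these reductions follow: PyTorch produces the gradient of a broadcast bias by summing the expanded gradient over the broadcast dimensions, which is what the sums over dims {2} and {0} reproduce:

#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t B = 2, H = 4, Sq = 8, Sk = 16;  // illustrative sizes
    // Bias shared across the batch and broadcast along seqlen_q.
    auto bias   = torch::zeros({1, H, 1, Sk}, torch::requires_grad());
    auto scores = torch::randn({B, H, Sq, Sk});

    (scores + bias).sum().backward();

    std::cout << bias.grad().sizes() << std::endl;  // [1, 4, 1, 16]
    // bias.grad() equals the expanded gradient summed over dims {0, 2},
    // i.e. over the batch and query-length dimensions.
}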
@@ -1238,6 +1244,7 @@ mha_varlen_bwd(
         softcap,
         has_mask,
         has_bias,
+        /*accum_dbias*/false,
         deterministic,
         /*unpadded_lse*/true
     );
