Commit 7535784

Removes accum_dbias; fixes zeros dtype
Simplifies bias-gradient handling by deriving accumulation from the bias sequence-length condition, removing the redundant parameter and its related plumbing. Aligns zero-initialization of bias buffers with the provided tensor options (no forced float), preventing mixed-precision dtype mismatches and improving correctness for MQA/GQA bias shapes. Streamlines the backward API with no intended behavior changes beyond the dtype fix.
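
For illustration, a minimal standalone sketch of the dtype fix using the PyTorch C++ API (the shapes, names, and CPU/fp16 setup here are hypothetical, not taken from this repo): a buffer zero-initialized with the caller's tensor options inherits the working dtype, whereas forcing at::kFloat produced a float32 buffer even when the surrounding tensors are fp16/bf16.

    #include <torch/torch.h>

    int main() {
        // Hypothetical mixed-precision setup: fp16 options, as a caller of the
        // backward pass might supply. (CPU here for simplicity; the real
        // kernels run on CUDA.)
        auto opts = torch::TensorOptions().dtype(torch::kHalf);
        auto dbias = torch::empty({2, 8, 1, 128}, opts);

        // Old behavior: the broadcast-bias scratch buffer forced float32,
        // mismatching dbias under mixed precision.
        auto old_scratch = torch::zeros({2, 8, 1, 128}, opts.dtype(torch::kFloat));

        // New behavior: the buffer inherits the provided tensor options, so
        // its dtype matches dbias and later reductions into it stay consistent.
        auto new_scratch = torch::zeros({2, 8, 1, 128}, opts);

        TORCH_CHECK(new_scratch.dtype() == dbias.dtype());
        TORCH_CHECK(old_scratch.dtype() != dbias.dtype());
        return 0;
    }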
1 parent b93a5c5 commit 7535784

File tree

1 file changed: +2 -8

csrc/flash_dmattn/flash_api.cpp

Lines changed: 2 additions & 8 deletions
@@ -190,7 +190,6 @@ void set_params_dgrad(
     const float softcap,
     bool has_mask,
     bool has_bias,
-    bool accum_dbias,
     bool deterministic,
     const bool unpadded_lse
 ) {
@@ -246,8 +245,6 @@ void set_params_dgrad(
     // Softmax sum
     params.dsoftmax_sum = dsoftmax_sum_d;
 
-    params.accum_dbias = accum_dbias;
-
     params.deterministic = deterministic;
 }
 
@@ -982,11 +979,10 @@ mha_bwd(
     dbias_expanded = has_bias
         ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1)    // MQA / GQA or dbias has different batch size or seqlen_q
             ? (seqlen_q_bias == 1)
-                ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
+                ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts)
                 : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
             : dbias
         : torch::empty({0}, opts);
-    bool accum_dbias = has_bias && (seqlen_q_bias == 1 && seqlen_q != 1);
 
     Flash_bwd_params params;
 
@@ -1013,7 +1009,6 @@ mha_bwd(
         softcap,
         has_mask,
         has_bias,
-        accum_dbias,
         deterministic,
         /*unpadded_lse*/false
     );
@@ -1041,7 +1036,7 @@ mha_bwd(
     if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
         at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
     } else {
-        if (accum_dbias) {
+        if (seqlen_q_bias == 1) {
             dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2});
         } else {
             dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
@@ -1244,7 +1239,6 @@ mha_varlen_bwd(
         softcap,
         has_mask,
         has_bias,
-        /*accum_dbias*/false,
         deterministic,
         /*unpadded_lse*/true
     );
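
For context, a standalone sketch of the MQA/GQA bias-gradient reduction that the mha_bwd hunks above touch (hypothetical shapes; variable names mirror the diff, but this is not the repo's code):

    #include <torch/torch.h>

    int main() {
        // Hypothetical GQA layout: 8 query heads sharing 2 bias heads.
        int64_t batch = 2, num_heads = 8, num_heads_bias = 2;
        int64_t seqlen_q = 16, seqlen_k_rounded = 128;

        // Per-query-head bias gradients, as accumulated into dbias_expanded.
        auto dbias_expanded = torch::randn({batch, num_heads, seqlen_q, seqlen_k_rounded});

        // Collapse each group of num_heads / num_heads_bias query heads onto
        // its shared bias head, mirroring the at::sum over dim 2 in mha_bwd.
        auto grouped = at::reshape(
            dbias_expanded,
            {batch, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded});
        auto dbias = at::sum(grouped, {2});

        // When seqlen_q_bias == 1 (bias broadcast over query positions), the
        // same reduction runs with a length-1 query dim -- the condition the
        // commit now tests directly instead of a separate accum_dbias flag.
        TORCH_CHECK(dbias.size(1) == num_heads_bias && dbias.size(2) == seqlen_q);
        return 0;
    }

Testing seqlen_q_bias == 1 at the use site keeps the reduction self-describing and removes one parameter from set_params_dgrad.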
