Commit 8effe3c

Refactors mha_bwd to use torch::zeros for bias initialization and removes unnecessary zeroing of dbias_expanded
Parent: 08392c8

1 file changed (+2, -5)

csrc/flash_dmattn/flash_api.cpp

@@ -979,13 +979,10 @@ mha_bwd(
     dbias_expanded = has_bias
         ? (
             (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)  // MQA / GQA or dbias has different batch size or seqlen_q
-            ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
+            ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
             : dbias
         )
         : torch::empty({0}, opts);
-    if (has_bias) {
-        dbias_expanded.zero_();
-    }
 
     Flash_bwd_params params;
 
@@ -1050,7 +1047,7 @@ mha_bwd(
         }
     }
 
-    return { dq, dk, dv, dbias, softmax_d };
+    return {dq, dk, dv, dbias, softmax_d};
 }
 
 std::vector<at::Tensor>
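For context, the refactor collapses an allocate-then-zero sequence into a single zero-initialized allocation, and drops the unconditional zeroing that previously ran even when dbias_expanded aliased dbias directly. A minimal standalone sketch of the two patterns using the libtorch C++ API (shapes and dtype are illustrative stand-ins, not the repository's values):

    #include <torch/torch.h>

    int main() {
        // Illustrative stand-ins for batch_size, num_heads, seqlen_q, seqlen_k_rounded.
        auto opts = torch::TensorOptions().dtype(torch::kFloat32);

        // Before: allocate uninitialized memory, then run a separate fill kernel.
        auto before = torch::empty({2, 4, 128, 128}, opts);
        before.zero_();

        // After: one call that allocates and zero-fills in a single step.
        auto after = torch::zeros({2, 4, 128, 128}, opts);

        // Both paths produce an all-zero tensor of the same shape.
        TORCH_CHECK(before.equal(after), "patterns should be equivalent");
        return 0;
    }

Note that after this change, when dbias_expanded reuses dbias directly (the non-expanded case), the buffer is no longer zeroed before the backward pass; the commit message treats that zeroing as unnecessary, presumably because the kernel fully overwrites the buffer.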
