
Commit 4dd087d

Merge pull request #191 from SmallDoges:fix-189
Refactor bias initialization and enhance bias computation in FlashDMAttnFunc
2 parents 464baf7 + 502a1a4 · commit 4dd087d

File tree

2 files changed: +5 / -6 lines changed

csrc/flash_dmattn/flash_api.cpp

Lines changed: 2 additions & 5 deletions
```diff
@@ -979,13 +979,10 @@ mha_bwd(
     dbias_expanded = has_bias
         ? (
             (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)  // MQA / GQA or dbias has different batch size or seqlen_q
-                ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
+                ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
                 : dbias
         )
         : torch::empty({0}, opts);
-    if (has_bias) {
-        dbias_expanded.zero_();
-    }
 
     Flash_bwd_params params;
 
@@ -1050,7 +1047,7 @@ mha_bwd(
         }
     }
 
-    return { dq, dk, dv, dbias, softmax_d };
+    return {dq, dk, dv, dbias, softmax_d};
 }
 
 std::vector<at::Tensor>
```
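The C++ change collapses the two-step `torch::empty(...)` followed by `dbias_expanded.zero_()` into a single `torch::zeros(...)`, and the zero-initialization now applies only to the freshly allocated expanded buffer; when `dbias_expanded` aliases the caller's `dbias`, that tensor is no longer cleared up front. A minimal PyTorch sketch of the same allocation rule; the name `make_dbias_buffer` and its simplified signature are hypothetical, not part of the library:

```python
import torch

def make_dbias_buffer(dbias, batch_size, num_heads, seqlen_q, seqlen_k_rounded,
                      batch_size_bias, num_heads_bias, seqlen_q_bias, has_bias=True):
    # Hypothetical Python mirror of the allocation logic in mha_bwd above.
    if not has_bias:
        return torch.empty(0)
    # MQA / GQA, or dbias has a different batch size or seqlen_q: accumulate
    # into an expanded buffer, which must start at zero because partial
    # gradients are added into it before being reduced back to dbias's shape.
    if (num_heads_bias != num_heads
            or batch_size_bias != batch_size
            or seqlen_q_bias != seqlen_q):
        # One zero-initialized allocation replaces empty() + a separate zero_().
        return torch.zeros(batch_size, num_heads, seqlen_q, seqlen_k_rounded,
                           dtype=dbias.dtype, device=dbias.device)
    # Shapes already match: write straight into the caller's dbias,
    # which the new code no longer clears.
    return dbias

# Example: a bias broadcast over heads (num_heads_bias=1) gets a zeroed expanded buffer.
dbias = torch.zeros(2, 1, 8, 16)
buf = make_dbias_buffer(dbias, 2, 4, 8, 16, 2, 1, 8)
assert buf.shape == (2, 4, 8, 16) and buf.abs().sum() == 0
```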

flash_dmattn/flash_dmattn_interface.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -409,6 +409,7 @@ def forward(
             deterministic = False
         if return_softmax is None:
             return_softmax = False
+        seqlen_k_bias_og = bias.shape[-1] if bias is not None else 0
 
         # Padding to multiple of 8 for 16-bit memory allocations
         head_size_og = q.size(3)
@@ -446,6 +447,7 @@ def forward(
         ctx.is_causal = is_causal
         ctx.softcap = softcap
         ctx.deterministic = deterministic
+        ctx.seqlen_k_bias_og = seqlen_k_bias_og
 
         out = out_padded[..., :head_size_og]
 
@@ -491,7 +493,7 @@ def backward(
         dv = dv[..., : dout.shape[-1]]
 
         if dbias is not None:
-            dbias = dbias[..., : k.shape[1]]
+            dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : k.shape[1]]
 
         return dq, dk, dv, None, dbias, None, None, None, None, None, None
```
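On the Python side, `forward` now records the bias's original key length (`seqlen_k_bias_og`) so that `backward` can return a gradient with the same shape the bias was passed in with: if the bias had a singleton key dimension that was broadcast over `seqlen_k`, the per-key gradients are summed back down with `keepdim=True`. A minimal sketch of that reduction rule, using a hypothetical `reduce_dbias` helper and made-up shapes, checked against autograd:

```python
import torch

def reduce_dbias(dbias_full, seqlen_k, seqlen_k_bias_og):
    # Drop padding beyond seqlen_k (the kernel works on a rounded-up length),
    # then, if the bias had a singleton key dimension that was broadcast in
    # the forward pass, sum the per-key gradients back down to shape [..., 1].
    dbias = dbias_full[..., :seqlen_k]
    if seqlen_k_bias_og == 1:
        return dbias.sum(dim=-1, keepdim=True)
    return dbias

# Sanity check against autograd with made-up shapes.
scores = torch.randn(2, 4, 8, 16)                    # [batch, heads, seqlen_q, seqlen_k]
bias = torch.randn(2, 4, 8, 1, requires_grad=True)   # broadcast along seqlen_k
(scores + bias).sum().backward()

dscores = torch.ones(2, 4, 8, 16)                    # d(loss)/d(scores) for a plain sum
assert torch.equal(reduce_dbias(dscores, 16, 1), bias.grad)
```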