Skip to content

Commit 77edcb0

Browse files
committed
Fixes bias tensor initialization in FlashDMAttnFunc to handle None case
1 parent 87ce7cc commit 77edcb0

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

flash_dmattn/flash_dmattn_interface.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,8 @@ def backward(
287287
*args: Any,
288288
):
289289
q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
290-
dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
290+
dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
291+
dbias = torch.zeros_like(bias) if bias is not None else None
291292

292293
head_size_og = dout.size(3)
293294
dout_padded = dout

0 commit comments

Comments
 (0)