Skip to content

Commit 730f40e

Browse files
committed
Fix accumulator init with optional bias
Initializes the score accumulator within the bias branch and zero-initializes otherwise. Prevents referencing an undefined bias when disabled and improves compilation stability in both forward and backward kernels.
1 parent 4291d01 commit 730f40e

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

flash_dmattn/flash_dmattn_triton.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,11 @@ def _fwd_kernel(
219219
& ((start_n + offs_n)[None, :] < seqlen_k),
220220
other=0.0,
221221
).to(tl.float32)
222+
acc_s = bias
223+
else:
224+
acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
222225

223226
# Compute acc_s
224-
acc_s = bias if HAS_BIAS else tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
225227
acc_s += tl.dot(q, tl.trans(k))
226228

227229
# Apply masks
@@ -507,9 +509,11 @@ def _bwd_kernel_one_col_block(
507509
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
508510
other=0.0,
509511
).to(tl.float32)
512+
acc_s = bias
513+
else:
514+
acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
510515

511516
# Compute acc_s
512-
acc_s = bias if HAS_BIAS else tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
513517
acc_s += tl.dot(q, tl.trans(k))
514518

515519
# Apply masks

0 commit comments

Comments
 (0)