[BUG FIX] Correct causal mask handling for longer KV pairs#213
[BUG FIX] Correct causal mask handling for longer KV pairs#213LoserCheems merged 1 commit intomainfrom
Conversation
Ensures the forward loop bounds respect shorter queries so causal attention ignores future tokens when kv > q Adjusts the causal mask offset to align with kv-q gaps and prevent invalid attention weights
There was a problem hiding this comment.
Pull request overview
This PR attempts to fix causal mask handling when the key-value sequence length exceeds the query sequence length, a common scenario in KV caching during inference. The changes modify the forward kernel's loop bounds and causal mask offset calculation.
Key Changes:
- Modified the
end_nloop bound calculation to account for longer KV sequences - Added offset
(seqlen_k - seqlen_q)to the causal mask condition to properly align query positions with the key timeline
Critical Issue Identified:
The logic in line 168 contains a bug that will break non-causal attention when seqlen_k > seqlen_q, as it incorrectly applies causal-like loop bounds in this case.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
| # Loop over k, v and update accumulator | ||
| end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) | ||
| end_n = seqlen_k if not IS_CAUSAL and seqlen_k <= seqlen_q else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) |
There was a problem hiding this comment.
The condition not IS_CAUSAL and seqlen_k <= seqlen_q is logically incorrect. This causes causal-like behavior to be applied when IS_CAUSAL is False and seqlen_k > seqlen_q, which breaks non-causal attention with longer KV sequences.
When IS_CAUSAL is False, end_n should always be seqlen_k regardless of the relationship between seqlen_k and seqlen_q. The condition should be:
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M + (seqlen_k - seqlen_q), seqlen_k)The offset (seqlen_k - seqlen_q) should only be applied to the causal case, not used to determine whether to apply causal behavior.
| end_n = seqlen_k if not IS_CAUSAL and seqlen_k <= seqlen_q else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) | |
| end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M + (seqlen_k - seqlen_q), seqlen_k) |
Summary
Root Cause
Changes
Reproduction
Tests
Compatibility
Checklist