diff --git a/flash_sparse_attn/flash_sparse_attn_triton.py b/flash_sparse_attn/flash_sparse_attn_triton.py index eefbf763..c3b14418 100644 --- a/flash_sparse_attn/flash_sparse_attn_triton.py +++ b/flash_sparse_attn/flash_sparse_attn_triton.py @@ -165,7 +165,7 @@ def _fwd_kernel( q = (q * softmax_scale).to(q.dtype) # Loop over k, v and update accumulator - end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M + (seqlen_k - seqlen_q), seqlen_k) for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) @@ -231,7 +231,7 @@ def _fwd_kernel( if not EVEN_N: # Need to mask out otherwise the softmax is wrong acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) if IS_CAUSAL: - acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + acc_s += tl.where(offs_m[:, None] + (seqlen_k - seqlen_q) >= (start_n + offs_n)[None, :], 0, float("-inf")) if HAS_MASK: acc_s += tl.where(mask, 0, float("-inf"))