Skip to content

Commit 8628b12

Browse files
committed
more prep for parallel sliding + fine
1 parent 26a8892 commit 8628b12

File tree

1 file changed

+24
-8
lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def forward_kernel_causal_and_sparse(
106106
QUERY_HEAD_GROUPS: tl.constexpr,
107107
QUERY_EXPAND_DIM: tl.constexpr,
108108
NUM_SEL_KV_BLOCKS: tl.constexpr,
109-
INCLUDE_BLOCK_CAUSAL: tl.constexpr
109+
INCLUDE_BLOCK_CAUSAL: tl.constexpr,
110+
SLIDING: tl.constexpr
110111
):
111112
start_m = tl.program_id(0)
112113
off_hb = tl.program_id(1)
@@ -117,7 +118,6 @@ def forward_kernel_causal_and_sparse(
117118
offs_qh = off_h * QUERY_HEAD_GROUPS + tl.arange(0, QUERY_HEAD_GROUPS)
118119

119120
offs_m = start_m * BLOCK + tl.arange(0, BLOCK)
120-
offs_n = start_m * BLOCK + tl.arange(0, BLOCK)
121121
offs_d = tl.arange(0, BLOCK_HEADDIM)
122122

123123
q_ptrs = (
@@ -181,6 +181,9 @@ def forward_kernel_causal_and_sparse(
181181
)
182182

183183
if INCLUDE_BLOCK_CAUSAL:
184+
185+
offs_n = start_m * BLOCK + tl.arange(0, BLOCK)
186+
184187
k_ptrs = (
185188
K +
186189
off_b * stride_kb +
@@ -225,11 +228,21 @@ def forward_kernel_causal_and_sparse(
225228
qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
226229

227230
if not EVEN_N:
228-
qk += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf"))
231+
within_range_mask = offs_n[None, :] < seqlen_k
232+
233+
if SLIDING:
234+
within_range_mask &= offs_n[None, :] >= 0.
235+
236+
qk += tl.where(within_range_mask, 0, float("-inf"))
229237

230238
qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
231239

232-
qk += tl.where(offs_m[:, None, None] >= offs_n[None, None, :], 0, float("-inf"))
240+
causal_mask = offs_m[:, None, None] >= offs_n[None, None, :]
241+
242+
if SLIDING:
243+
causal_mask &= (offs_n[None, None, :] - offs_m[:, None, None]) <= BLOCK
244+
245+
qk += tl.where(causal_mask, 0, float("-inf"))
233246

234247
m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
235248
p = tl.exp(qk * softmax_scale - m_ij[:, :, None])
@@ -456,7 +469,8 @@ def forward_kernel(
456469
QUERY_HEAD_GROUPS: tl.constexpr,
457470
QUERY_EXPAND_DIM: tl.constexpr,
458471
NUM_SEL_KV_BLOCKS: tl.constexpr,
459-
INCLUDE_BLOCK_CAUSAL: tl.constexpr
472+
INCLUDE_BLOCK_CAUSAL: tl.constexpr,
473+
SLIDING: tl.constexpr
460474
):
461475
forward_kernel_causal_and_sparse(
462476
Q,
@@ -498,7 +512,8 @@ def forward_kernel(
498512
QUERY_HEAD_GROUPS,
499513
QUERY_EXPAND_DIM,
500514
NUM_SEL_KV_BLOCKS,
501-
INCLUDE_BLOCK_CAUSAL
515+
INCLUDE_BLOCK_CAUSAL,
516+
SLIDING
502517
)
503518

504519
def native_sparse_attn_forward(
@@ -578,6 +593,7 @@ def native_sparse_attn_forward(
578593
QUERY_EXPAND_DIM = 16 // head_groups,
579594
NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
580595
INCLUDE_BLOCK_CAUSAL = include_block_causal,
596+
SLIDING = False,
581597
num_warps = num_warps,
582598
num_stages = 1,
583599
)
@@ -626,8 +642,8 @@ def backward_preprocess_do_o_dot(
626642
+ off_h * stride_doh
627643
+ offs_m[:, None] * stride_dom
628644
+ offs_d[None, :],
629-
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
630-
other=0.0,
645+
mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
646+
other = 0.0,
631647
).to(tl.float32)
632648

633649
delta = tl.sum(o * do, axis=1)

0 commit comments

Comments (0)