Commit 4e39402

forward sort of works with a bunch of hacks

1 parent 202b21e · commit 4e39402

2 files changed: +83 -9 lines

native_sparse_attention_pytorch/triton_native_sparse_attention.py
Lines changed: 79 additions & 5 deletions
@@ -34,7 +34,7 @@ def is_contiguous(x: Tensor):
 from importlib.metadata import version
 
 try:
-    triton_version = version('triton-nightly')
+    triton_version = version('triton')
 except:
     print(f'latest triton must be installed. `{INSTALL_COMMAND}` first')
     exit()
@@ -216,17 +216,91 @@ def _fwd_kernel(
         l_i_new = tl.exp(lse_i - m_ij) + l_ij
         lse_i = m_ij + tl.log(l_i_new)
 
+    # take care of the selected kv blocks
+
+    kv_block_indices_ptrs = (
+        KV_block_indices +
+        off_b * stride_kvbl_b +
+        off_h * stride_kvbl_h +
+        offs_m * stride_kvbl_m
+    )
+
+    kv_block_mask_ptrs = (
+        KV_block_mask +
+        off_b * stride_kvbl_b +
+        off_h * stride_kvbl_h +
+        offs_m * stride_kvbl_m
+    )
+
+    for off_sel_kv_block in range(NUM_SEL_KV_BLOCKS):
+        block_indices = tl.load(kv_block_indices_ptrs + off_sel_kv_block)
+        block_masks = tl.load(kv_block_mask_ptrs + off_sel_kv_block)
+
+        blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]
+
+        block_k_ptrs = (
+            K + off_b * stride_kb + off_h * stride_kh + (blocks_offs_n[:, :, None] * stride_kn + offs_d[None, None, :])
+        )
+
+        block_v_ptrs = (
+            V + off_b * stride_vb + off_h * stride_vh + (blocks_offs_n[:, :, None] * stride_vn + offs_d[None, None, :])
+        )
+
+        # load k of shape (m, n, d), sparsely selected by each query
+
+        k_block = tl.load(block_k_ptrs)
+
+        # similarities
+
+        block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
+        qk = tl.zeros([BLOCK, BLOCK], dtype = tl.float32)
+
+        k_block = tl.reshape(k_block, (BLOCK, BLOCK, BLOCK_HEADDIM))
+        k_block = tl.permute(k_block, (0, 2, 1))
+
+        q_expanded = tl.expand_dims(q, 1)
+        q_expanded = tl.broadcast_to(q_expanded, (BLOCK, 16, BLOCK_HEADDIM))
+
+        block_qk = tl.dot(q_expanded, k_block)
+        qk += tl.sum(block_qk, 1) / 16.
+        qk += tl.where(block_masks[:, None], 0, float("-inf"))
+
+        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
+        p = tl.exp(qk * softmax_scale - m_ij[:, None])
+
+        l_ij = tl.sum(p, 1)
+
+        acc_o_scale = tl.exp(m_i - m_ij)
+        acc_o = acc_o * acc_o_scale[:, None]
+
+        v_block = tl.load(block_v_ptrs)
+        v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))
+
+        p = p.to(v_block.dtype)
+        p_expanded = tl.expand_dims(p, 1)
+        p_expanded = tl.broadcast_to(p_expanded, (BLOCK, 16, BLOCK))
+
+        block_acc_o = tl.dot(p_expanded, v_block)
+        block_acc_o = tl.sum(block_acc_o, 1) / 16.
+        acc_o += block_acc_o
+
+        # -- update statistics
+
+        m_i = m_ij
+        l_i_new = tl.exp(lse_i - m_ij) + l_ij
+        lse_i = m_ij + tl.log(l_i_new)
+
     # normalize accumulated out
 
     acc_o_scale = tl.exp(m_i - lse_i)
     acc_o = acc_o * acc_o_scale[:, None]
 
-    # offsets for m and lse
+    # offsets
 
     start_m = tl.program_id(0)
     offs_m = start_m * BLOCK + tl.arange(0, BLOCK)
 
-    # write back lse and m
+    # write back lse
 
     tl.store(lse_ptrs, lse_i)
 
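For reference, here is a minimal PyTorch sketch (an illustration, not code from this commit) of what the added selected-block pass computes for a single (batch, head) program: gather each query's selected KV block, score it, and fold the result into the running output with the same online-softmax rescaling the kernel uses. The broadcast of `q` and `p` to a dummy size-16 dimension in the Triton code, followed by `tl.sum(..., 1) / 16.`, appears to be a workaround for `tl.dot`'s minimum tile size, presumably one of the hacks the commit message refers to; the plain PyTorch version needs no such padding.

import torch

# Reference semantics only; names and shapes are assumptions, not taken from the repo.
# q: (m, d) queries for one (batch, head) program
# k, v: (n_kv, d) key / value rows for that (batch, head)
# block_indices, block_masks: (m,) selected block id and validity per query, for one slot
# acc_o: (m, d) running output; m_i, lse_i: (m,) running max and logsumexp

def selected_block_step(q, k, v, block_indices, block_masks, acc_o, m_i, lse_i, block, scale):
    # gather each query's selected block of keys / values: (m, block, d)
    offs = block_indices[:, None] * block + torch.arange(block, device = q.device)
    k_blk, v_blk = k[offs], v[offs]

    # per-query similarities against its own selected block: (m, block)
    qk = torch.einsum('md,mnd->mn', q, k_blk)
    qk = qk.masked_fill(~block_masks[:, None], float('-inf'))

    # online softmax update, mirroring the kernel's running statistics
    m_ij = torch.maximum(qk.amax(dim = -1) * scale, lse_i)
    p = torch.exp(qk * scale - m_ij[:, None])
    l_ij = p.sum(dim = -1)

    acc_o = acc_o * torch.exp(m_i - m_ij)[:, None]
    acc_o = acc_o + torch.einsum('mn,mnd->md', p, v_blk)

    m_i = m_ij
    lse_i = m_ij + torch.log(torch.exp(lse_i - m_ij) + l_ij)
    return acc_o, m_i, lse_i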
@@ -253,7 +327,7 @@ def flash_attn_forward(
     kv_block_mask,
     block_size = 128
 ):
-    q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
+    q, k, v, kv_block_indices = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v, kv_block_indices)]
 
     batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape
@@ -266,7 +340,7 @@ def flash_attn_forward(
     assert dim <= 128, "only support head dimensions up to 128"
     assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
     assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
-    assert q.is_cuda and k.is_cuda and v.is_cuda
+    assert all([t.is_cuda for t in (q, k, v)])
 
     softmax_scale = dim ** -0.5
 
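The asserts above spell out the input contract of the Triton forward path: tensors in (batch, seqlen, heads, dim) layout, head dimension at most 128, fp16 or bf16, on CUDA, and made contiguous on entry. A quick illustration of inputs that satisfy it (sizes below are assumptions, not taken from the repository's tests):

import torch

# illustrative tensors satisfying the checks in flash_attn_forward
batch, seqlen, heads, dim = 1, 512, 2, 64

q = torch.randn(batch, seqlen, heads, dim, device = 'cuda', dtype = torch.float16)
k = torch.randn(batch, seqlen, heads, dim, device = 'cuda', dtype = torch.float16)
v = torch.randn(batch, seqlen, heads, dim, device = 'cuda', dtype = torch.float16)

assert dim <= 128
assert q.dtype == k.dtype == v.dtype and q.dtype in (torch.float16, torch.bfloat16)
assert all(t.is_cuda for t in (q, k, v))

softmax_scale = dim ** -0.5        # same default scaling the function derives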
test_triton_nsa.py
Lines changed: 4 additions & 4 deletions
@@ -42,7 +42,7 @@ def regular_attend(
     sel_bv = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bv, indices)
 
     q = rearrange(q, 'b (h w) n d -> b h (w n) d', h = q_heads)
-    bsim = einsum(q, sel_bk, 'b h i d, b h i j d -> b h i j') * scale
+    bsim = einsum(q, sel_bk, 'b h i d, b h i j d -> b h i j')
 
     bsim = rearrange(bsim, 'b h (w i) (sel j) -> b h w i sel j', sel = num_sel_kv_blocks, i = fine_block_size)
 
@@ -76,14 +76,14 @@ def regular_attend(
 
 # mock inputs
 
-fine_block_size = 64
+fine_block_size = 16
 
 q = torch.randn(1, 2, 512, 64).cuda()
 k = torch.randn(1, 2, 512, 64).cuda()
 v = torch.randn(1, 2, 512, 64).cuda()
 
-indices = torch.zeros(1, 2, 512, 0).long().cuda()
-mask = torch.zeros(1, 2, 512, 0).bool().cuda()
+indices = torch.zeros(1, 2, 512, 2).long().cuda()
+mask = torch.ones(1, 2, 512, 2).bool().cuda()
 
 # both regular and nsa pathways `r` and `n`
 
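The updated mock inputs now exercise two selected fine blocks of size 16 per query (previously zero), with every selection marked live in the mask. A small sketch of what such indices select, roughly what the test's einx.get_at line gathers before the selected-block and within-block dims are merged (names below are illustrative, not from the test file):

import torch

# keys in (batch, heads, seq, dim), split into fine blocks of 16
b, h, n, d = 1, 2, 512, 64
fine_block_size = 16

k = torch.randn(b, h, n, d)
k_blocks = k.view(b, h, n // fine_block_size, fine_block_size, d)    # (b, h, 32, 16, d)

# two selected block ids per query position, all marked valid, as in the updated test
indices = torch.zeros(b, h, n, 2).long()
mask = torch.ones(b, h, n, 2).bool()

# gather each query's selected key blocks: (b, h, n, 2, 16, d)
sel_k = k_blocks[
    torch.arange(b)[:, None, None, None],
    torch.arange(h)[None, :, None, None],
    indices,
]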