Skip to content

Commit ee3745c

Browse files
committed
start making way through recompute
1 parent 9077c5f commit ee3745c

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,14 +265,20 @@ def forward_kernel(
265265
qk += tl.sum(block_qk, 1) / 16.
266266
qk += tl.where(block_masks[:, None], 0, float("-inf"))
267267

268+
# attention
269+
268270
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
269271
p = tl.exp(qk * softmax_scale - m_ij[:, None])
270272

271273
l_ij = tl.sum(p, 1)
272274

275+
# renormalize the running output
276+
273277
acc_o_scale = tl.exp(m_i - m_ij)
274278
acc_o = acc_o * acc_o_scale[:, None]
275279

280+
# aggregate values
281+
276282
v_block = tl.load(block_v_ptrs)
277283
v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))
278284

@@ -748,6 +754,17 @@ def backward_kernel_one_col_block(
748754
block_k = tl.load(block_k_ptrs)
749755
block_v = tl.load(block_v_ptrs)
750756

757+
q_expanded = tl.expand_dims(q, 1)
758+
q_expanded = tl.broadcast_to(q_expanded, (BLOCK, 16, BLOCK_HEADDIM))
759+
760+
block_k = tl.permute(block_k, (0, 2, 1))
761+
block_qk = tl.dot(q_expanded, block_k)
762+
763+
qk = tl.sum(block_qk, 1) / 16.
764+
qk += tl.where(block_masks[:, None], 0, float("-inf"))
765+
766+
p = tl.exp(qk * softmax_scale - lse_i[:, None])
767+
751768
# # increment pointers
752769
# dq_ptrs += BLOCK * stride_dqm
753770
# q_ptrs += BLOCK * stride_qm

0 commit comments

Comments (0)