
Commit d5a76e1

gqa forward triton kernel complete
1 parent 0aaaea8 commit d5a76e1

2 files changed: +34, -16 lines


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 32 additions & 14 deletions
@@ -95,6 +95,7 @@ def forward_kernel(
     EVEN_HEADDIM: tl.constexpr,
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
+    QUERY_EXPAND_DIM: tl.constexpr,
     NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     start_m = tl.program_id(0)
@@ -261,6 +262,12 @@ def forward_kernel(
         offs_m * stride_kvbl_m
     )

+    q = q.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM)
+    q = q.permute((1, 0, 2))
+    q = tl.expand_dims(q, 2)
+    q = tl.broadcast_to(q, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM))
+    q = q.reshape(BLOCK, 16, BLOCK_HEADDIM)
+
     for off_sel_kv_block in range(NUM_SEL_KV_BLOCKS):
         block_indices = tl.load(kv_block_indices_ptrs + off_sel_kv_block)
         block_masks = tl.load(kv_block_mask_ptrs + off_sel_kv_block)
@@ -282,18 +289,21 @@ def forward_kernel(
         # similarities

         block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
-        qk = tl.zeros([BLOCK, BLOCK], dtype = tl.float32)
+        qk = tl.zeros([QUERY_HEAD_GROUPS, BLOCK, BLOCK], dtype = tl.float32)

-        k_block = tl.reshape(k_block, (BLOCK, BLOCK, BLOCK_HEADDIM))
-        k_block = tl.permute(k_block, (0, 2, 1))
+        k_block = k_block.reshape(BLOCK, BLOCK, BLOCK_HEADDIM)
+        k_block = k_block.permute(0, 2, 1)

-        q_expanded = tl.expand_dims(q, 1)
-        q_expanded = tl.broadcast_to(q_expanded, (BLOCK, 16, BLOCK_HEADDIM))
+        block_qk = tl.dot(q, k_block)
+        block_qk = block_qk.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK)
+        block_qk = tl.sum(block_qk, 2) / QUERY_EXPAND_DIM
+        block_qk = block_qk.permute(1, 0, 2)

-        block_qk = tl.dot(q_expanded, k_block)
-        qk += tl.sum(block_qk, 1) / 16.
+        qk += block_qk
         qk += tl.where(block_masks[:, None], 0, float("-inf"))

+        qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)
+
         # attention

         m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
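
A note on the two hunks above, as a reading of the change rather than anything stated in the commit: `tl.dot` wants tile dimensions of at least 16, and the query head-group axis is smaller than that, so each group looks to be duplicated QUERY_EXPAND_DIM = 16 // QUERY_HEAD_GROUPS times before the dot, with the identical copies averaged back out afterwards. Below is a minimal PyTorch sketch of that equivalence; the toy sizes (G, BLOCK, D) are assumptions, not values taken from the repository.

import torch

# assumed toy sizes: G ~ QUERY_HEAD_GROUPS, E ~ QUERY_EXPAND_DIM, D ~ BLOCK_HEADDIM
G, BLOCK, D = 4, 32, 64
E = 16 // G

q = torch.randn(G, BLOCK, D)      # grouped queries sharing one kv head
k = torch.randn(BLOCK, BLOCK, D)  # one selected kv block per query row

# reference: plain per-group similarities -> (G, BLOCK, BLOCK)
qk_ref = torch.einsum('g i d, i j d -> g i j', q, k)

# kernel-style route: pad the group axis to 16 with duplicates, do one batched
# matmul, then average the identical copies back out
q_exp = q.permute(1, 0, 2)[:, :, None, :].expand(BLOCK, G, E, D).reshape(BLOCK, G * E, D)
block_qk = q_exp @ k.transpose(1, 2)                  # (BLOCK, 16, BLOCK)
qk = block_qk.reshape(BLOCK, G, E, BLOCK).sum(2) / E  # collapse the duplicated axis
qk = qk.permute(1, 0, 2)                              # back to (G, BLOCK, BLOCK)

assert torch.allclose(qk, qk_ref, atol = 1e-5)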
@@ -312,11 +322,18 @@ def forward_kernel(
         v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))

         p = p.to(v_block.dtype)
-        p_expanded = tl.expand_dims(p, 1)
-        p_expanded = tl.broadcast_to(p_expanded, (BLOCK, 16, BLOCK))
+        p_expanded = p.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
+        p_expanded = p_expanded.permute(1, 0, 2)
+        p_expanded = tl.expand_dims(p_expanded, 2)
+        p_expanded = tl.broadcast_to(p_expanded, (BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK))
+        p_expanded = p_expanded.reshape(BLOCK, 16, BLOCK)

         block_acc_o = tl.dot(p_expanded, v_block)
-        block_acc_o = tl.sum(block_acc_o, 1) / 16.
+        block_acc_o = block_acc_o.reshape(BLOCK, QUERY_HEAD_GROUPS, QUERY_EXPAND_DIM, BLOCK_HEADDIM)
+        block_acc_o = tl.sum(block_acc_o, 2) / QUERY_EXPAND_DIM
+        block_acc_o = block_acc_o.permute(1, 0, 2)
+        block_acc_o = block_acc_o.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)
+
         acc_o += block_acc_o

         # -- update statistics
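
The same padding trick appears to be applied to the attention weights when they are multiplied against the selected value block. Continuing the sketch above with the same assumed toy sizes, purely as an illustration:

import torch

G, BLOCK, D = 4, 32, 64
E = 16 // G

p = torch.randn(G * BLOCK, BLOCK).softmax(dim = -1)  # attention weights, shaped (G * BLOCK, BLOCK) as after the qk reshape
v = torch.randn(BLOCK, BLOCK, D)                     # one selected value block per query row

# reference: per-group weighted sum of values -> (G, BLOCK, D)
out_ref = torch.einsum('g i j, i j d -> g i d', p.reshape(G, BLOCK, BLOCK), v)

# kernel-style route, mirroring the p_expanded / block_acc_o lines above
p_exp = p.reshape(G, BLOCK, BLOCK).permute(1, 0, 2)[:, :, None, :]
p_exp = p_exp.expand(BLOCK, G, E, BLOCK).reshape(BLOCK, G * E, BLOCK)
acc = (p_exp @ v).reshape(BLOCK, G, E, D).sum(2) / E
acc = acc.permute(1, 0, 2).reshape(G * BLOCK, D)

assert torch.allclose(acc, out_ref.reshape(G * BLOCK, D), atol = 1e-5)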
@@ -352,7 +369,7 @@ def forward_kernel(
         out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
     )

-def flash_attn_forward(
+def native_sparse_attn_forward(
     q,
     k,
     v,
@@ -424,6 +441,7 @@ def flash_attn_forward(
         BLOCK_HEADDIM,
         BLOCK = block_size,
         QUERY_HEAD_GROUPS = head_groups,
+        QUERY_EXPAND_DIM = 16 // head_groups,
         NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
         num_warps = num_warps,
         num_stages = 1,
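
One implication of this new launch argument, assuming the padding reading above is right: `head_groups` has to divide 16 evenly, since QUERY_HEAD_GROUPS * QUERY_EXPAND_DIM must land exactly on the fixed 16 used in the kernel's reshapes. A tiny hypothetical check, not code from the repository:

# hypothetical sanity check mirroring QUERY_EXPAND_DIM = 16 // head_groups
for head_groups in (1, 2, 4, 8, 16):
    query_expand_dim = 16 // head_groups
    assert head_groups * query_expand_dim == 16  # the (BLOCK, 16, ...) reshape lines up

# a non-divisor such as head_groups = 3 would truncate (16 // 3 == 5, and 3 * 5 != 16),
# so the fixed-width-16 tiles inside the kernel could not be formed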
@@ -978,7 +996,7 @@ def backward_kernel(
         NUM_SEL_KV_BLOCKS = NUM_SEL_KV_BLOCKS
     )

-def flash_attn_backward(
+def native_sparse_attn_backward(
     do,
     q, k, v,
     kv_block_indices,
@@ -1128,7 +1146,7 @@ def forward(

         fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))

-        out, lse = flash_attn_forward(
+        out, lse = native_sparse_attn_forward(
             fq, fk, fv,
             selected_block_indices,
             fmask,
@@ -1162,7 +1180,7 @@ def backward(self, ctx, do):
         dk = torch.zeros(k.shape, dtype = torch.float32, device = device)
         dv = torch.zeros(v.shape, dtype = torch.float32, device = device)

-        flash_attn_backward(
+        native_sparse_attn_backward(
             do, q, k, v,
             sel_block_indices, mask,
             out, lse, dq, dk, dv,

test_triton_nsa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def regular_attend(
8787
k = torch.randn(1, 2, 512, 64).cuda()
8888
v = torch.randn(1, 2, 512, 64).cuda()
8989

90-
indices = torch.zeros(1, 2, 512, 0).long().cuda()
91-
mask = torch.ones(1, 2, 512, 0).bool().cuda()
90+
indices = torch.zeros(1, 2, 512, 1).long().cuda()
91+
mask = torch.ones(1, 2, 512, 1).bool().cuda()
9292

9393
# both regular and nsa pathways `r` and `n`
9494

0 commit comments

Comments
 (0)