
Commit 0aaaea8

some progress on the forward kernel, loading all grouped query heads
1 parent f533c14 commit 0aaaea8
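
For context on the commit message: the change makes each kernel program handle one kv head together with all of its grouped query heads, rather than a single query head. Below is a minimal plain-PyTorch sketch of that grouping (not the Triton kernel itself; the function name and shapes are illustrative only).

import torch
from einops import rearrange, einsum

def grouped_query_attention_reference(q, k, v):
    # q: (batch, q_heads, n, d), k / v: (batch, kv_heads, n, d)
    batch, q_heads, n, d = q.shape
    kv_heads = k.shape[1]
    assert q_heads % kv_heads == 0
    g = q_heads // kv_heads  # analogous to QUERY_HEAD_GROUPS in the kernel

    # fold each group of query heads next to its shared kv head,
    # mirroring how the kernel now processes a whole query-head group per program
    q = rearrange(q, 'b (h g) n d -> b h g n d', g = g)

    sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j') * (d ** -0.5)

    causal_mask = torch.ones(n, n, dtype = torch.bool, device = q.device).triu(1)
    sim = sim.masked_fill(causal_mask, float('-inf'))

    attn = sim.softmax(dim = -1)
    out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
    return rearrange(out, 'b h g n d -> b (h g) n d')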

2 files changed: +59 -23 lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 56 additions & 20 deletions
@@ -94,49 +94,72 @@ def forward_kernel(
     EVEN_N: tl.constexpr,
     EVEN_HEADDIM: tl.constexpr,
     BLOCK: tl.constexpr,
+    QUERY_HEAD_GROUPS: tl.constexpr,
     NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     start_m = tl.program_id(0)
     off_hb = tl.program_id(1)
     off_b = off_hb // nheads
+
     off_h = off_hb % nheads
 
+    offs_qh = off_h * QUERY_HEAD_GROUPS + tl.arange(0, QUERY_HEAD_GROUPS)
+
     offs_m = start_m * BLOCK + tl.arange(0, BLOCK)
     offs_n = start_m * BLOCK + tl.arange(0, BLOCK)
     offs_d = tl.arange(0, BLOCK_HEADDIM)
 
     q_ptrs = (
-        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
+        Q +
+        off_b * stride_qb +
+        offs_qh[:, None, None] * stride_qh +
+        offs_m[None, :, None] * stride_qm +
+        offs_d[None, None, :]
     )
+
     k_ptrs = (
-        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
+        K +
+        off_b * stride_kb +
+        off_h * stride_kh +
+        offs_n[:, None] * stride_kn +
+        offs_d[None, :]
     )
+
     v_ptrs = (
-        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
+        V +
+        off_b * stride_vb +
+        off_h * stride_vh +
+        offs_n[:, None] * stride_vn +
+        offs_d[None, :]
     )
 
     # maximum
 
-    m_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")
+    m_i = tl.zeros([BLOCK * QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
 
     # lse
 
-    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
+    offs_lse_qh = tl.arange(0, QUERY_HEAD_GROUPS)
+
+    lse_ptrs = (
+        Lse +
+        (off_hb + offs_lse_qh[:, None]) * seqlen_q_rounded +
+        offs_m[None, :]
+    )
 
-    lse_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")
+    lse_i = tl.zeros([BLOCK * QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
 
     # output
 
-    offs_d = tl.arange(0, BLOCK_HEADDIM)
-
     out_ptrs = (
-        Out
-        + off_b * stride_ob
-        + off_h * stride_oh
-        + (offs_m[:, None] * stride_om + offs_d[None, :])
+        Out +
+        off_b * stride_ob +
+        offs_qh[:, None, None] * stride_oh +
+        offs_m[None, :, None] * stride_om +
+        offs_d[None, None, :]
     )
 
-    acc_o = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
+    acc_o = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
 
     # load queries, keys, values
 
@@ -153,6 +176,8 @@ def forward_kernel(
             q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
         )
 
+    q = q.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM])
+
     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
             k = tl.load(k_ptrs)
@@ -172,14 +197,18 @@ def forward_kernel(
                 other=0.0,
             )
 
-    qk = tl.zeros([BLOCK, BLOCK], dtype=tl.float32)
+    qk = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK], dtype=tl.float32)
     qk += tl.dot(q, tl.trans(k))
 
     if not EVEN_N:
         qk += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf"))
 
+    qk = qk.reshape([QUERY_HEAD_GROUPS, BLOCK, BLOCK])
+
     qk += tl.where(offs_m[:, None] >= offs_n[None, :], 0, float("-inf"))
 
+    qk = qk.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK])
+
     m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
     p = tl.exp(qk * softmax_scale - m_ij[:, None])
 
@@ -303,10 +332,13 @@ def forward_kernel(
 
     # write back lse
 
+    lse_i = lse_i.reshape([QUERY_HEAD_GROUPS, BLOCK])
     tl.store(lse_ptrs, lse_i)
 
     # write to output
 
+    acc_o = acc_o.reshape([QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM])
+
     if EVEN_M:
         if EVEN_HEADDIM:
             tl.store(out_ptrs, acc_o)
@@ -331,13 +363,15 @@ def flash_attn_forward(
     q, k, v, kv_block_indices = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v, kv_block_indices)]
 
     batch, nheads, seqlen_q, dim, device = *q.shape, q.device
-    _, _, seqlen_k, _ = k.shape
+    _, kv_heads, seqlen_k, _ = k.shape
+    assert divisible_by(nheads, kv_heads)
+    head_groups = nheads // kv_heads
 
     num_selected_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape
 
-    assert k.shape == (batch, nheads, seqlen_k, dim)
-    assert v.shape == (batch, nheads, seqlen_k, dim)
+    assert k.shape == (batch, kv_heads, seqlen_k, dim)
+    assert v.shape == (batch, kv_heads, seqlen_k, dim)
     assert dim <= 128, "only support head dimensions up to 128"
     assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
     assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
@@ -353,7 +387,8 @@ def flash_attn_forward(
 
     BLOCK_HEADDIM = max(triton.next_power_of_2(dim), 16)
     num_warps = 4 if dim <= 64 else 8
-    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)
+
+    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * kv_heads) # kv heads here, as grouped query heads all loaded, following the paper
 
     forward_kernel[grid](
         q,
@@ -388,6 +423,7 @@ def flash_attn_forward(
         seqlen_k // 32,
         BLOCK_HEADDIM,
         BLOCK = block_size,
+        QUERY_HEAD_GROUPS = head_groups,
         NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
        num_warps = num_warps,
        num_stages = 1,
@@ -1090,8 +1126,6 @@ def forward(
         assert divisible_by(q_heads, kv_heads)
         head_groups = q_heads // kv_heads
 
-        fk, fv, selected_block_indices, fmask = tuple(repeat(t, 'b h ... -> b (h g) ...', g = head_groups) for t in (fk, fv, selected_block_indices, fmask))
-
         fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))
 
         out, lse = flash_attn_forward(
@@ -1101,6 +1135,8 @@ def forward(
             block_size = block_size
         )
 
+        fk, fv, selected_block_indices, fmask = tuple(repeat(t, 'b h ... -> b (h g) ...', g = head_groups) for t in (fk, fv, selected_block_indices, fmask))
+
         ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
 
         ctx._saved_variables = (
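
A small sketch of the launch arithmetic the grid change above implies (all numbers below are made up for illustration): the second grid axis now runs over batch * kv_heads, and each program covers head_groups query heads for a single kv head.

import triton

batch, nheads, kv_heads, seqlen_q, block_size = 1, 4, 2, 512, 64
assert nheads % kv_heads == 0
head_groups = nheads // kv_heads  # passed to the kernel as QUERY_HEAD_GROUPS

grid = (triton.cdiv(seqlen_q, block_size), batch * kv_heads)
# each of the grid[1] programs loads a (head_groups * BLOCK, BLOCK_HEADDIM)
# query tile, instead of a (BLOCK, BLOCK_HEADDIM) tile per individual query head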

test_triton_nsa.py

Lines changed: 3 additions & 3 deletions
@@ -75,7 +75,7 @@ def regular_attend(
     if has_sel_kv_blocks:
         out = einsum(attn, v, 'b h g w i j, b h w i j d -> b h g w i d')
     else:
-        out = einsum(attn, v, 'b h g w i j, b h j d -> b h g w i d')
+        out = einsum(attn, v, 'b h g w i j, b h w j d -> b h g w i d')
 
     return rearrange(out, 'b h g w n d -> b (h g) (w n) d')
 
@@ -87,8 +87,8 @@ def regular_attend(
 k = torch.randn(1, 2, 512, 64).cuda()
 v = torch.randn(1, 2, 512, 64).cuda()
 
-indices = torch.zeros(1, 2, 512, 1).long().cuda()
-mask = torch.ones(1, 2, 512, 1).bool().cuda()
+indices = torch.zeros(1, 2, 512, 0).long().cuda()
+mask = torch.ones(1, 2, 512, 0).bool().cuda()
 
 # both regular and nsa pathways `r` and `n`
 
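
A quick shape check of the corrected einsum pattern above (dimensions are made up): in the branch without selected kv blocks, v is still organized per window w, so the pattern needs 'b h w j d' rather than 'b h j d'.

import torch
from einops import einsum

b, h, g, w, i, j, d = 1, 2, 2, 8, 64, 64, 16
attn = torch.randn(b, h, g, w, i, j)
v = torch.randn(b, h, w, j, d)
out = einsum(attn, v, 'b h g w i j, b h w j d -> b h g w i d')
assert out.shape == (b, h, g, w, i, d)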

0 commit comments
