Skip to content

Commit 12f9eda

Browse files
committed
Fix offset bugs for compressed KV
1 parent: f56759d · commit: 12f9eda

File tree

1 file changed: 41 additions (+41) and 24 deletions (-24)

native_sparse_attention/ops/parallel.py

Lines changed: 41 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,8 @@
1212
import triton.language.core as core
1313
from einops import rearrange
1414

15-
from fla.ops.common.utils import (prepare_chunk_indices, prepare_lens,
16-
prepare_token_indices)
15+
from fla.ops.common.utils import (prepare_chunk_indices, prepare_chunk_offsets,
16+
prepare_lens, prepare_token_indices)
1717
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
1818
from native_sparse_attention.ops.utils import _bitonic_merge
1919

@@ -48,6 +48,7 @@ def parallel_nsa_compression_fwd_kernel(
4848
scale,
4949
offsets,
5050
token_indices,
51+
chunk_offsets,
5152
T,
5253
H: tl.constexpr,
5354
HQ: tl.constexpr,
@@ -67,8 +68,10 @@ def parallel_nsa_compression_fwd_kernel(
6768
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
6869
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
6970
T = eos - bos
71+
boc = tl.load(chunk_offsets + i_n).to(tl.int32)
7072
else:
7173
bos, eos = i_b * T, i_b * T + T
74+
boc = i_b * tl.cdiv(T, BS)
7275

7376
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
7477

@@ -94,8 +97,8 @@ def parallel_nsa_compression_fwd_kernel(
9497
for i_c in range(0, NC, BC):
9598
o_c = i_c + tl.arange(0, BC)
9699

97-
p_k = tl.make_block_ptr(k + (tl.cdiv(bos, BS) * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
98-
p_v = tl.make_block_ptr(v + (tl.cdiv(bos, BS) * H + i_h) * V, (TC, V), (H*V, 1), (i_c, i_v * BV), (BC, BV), (1, 0))
100+
p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
101+
p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c, i_v * BV), (BC, BV), (1, 0))
99102
# [BK, BC]
100103
b_k = tl.load(p_k, boundary_check=(0, 1))
101104
# [BC, BV]
@@ -150,6 +153,7 @@ def parallel_nsa_compression_bwd_kernel_dq(
150153
scale,
151154
offsets,
152155
token_indices,
156+
chunk_offsets,
153157
T,
154158
B: tl.constexpr,
155159
H: tl.constexpr,
@@ -170,8 +174,10 @@ def parallel_nsa_compression_bwd_kernel_dq(
170174
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
171175
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
172176
T = eos - bos
177+
boc = tl.load(chunk_offsets + i_n).to(tl.int32)
173178
else:
174179
bos, eos = i_b * T, i_b * T + T
180+
boc = i_b * tl.cdiv(T, BS)
175181

176182
q += (bos + i_t) * HQ*K
177183
do += (bos + i_t) * HQ*V
@@ -206,8 +212,8 @@ def parallel_nsa_compression_bwd_kernel_dq(
206212
b_dq = tl.zeros([G, BK], dtype=tl.float32)
207213
for i_c in range(0, NC, BC):
208214
o_c = i_c + tl.arange(0, BC)
209-
p_k = tl.make_block_ptr(k + (tl.cdiv(bos, BS) * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
210-
p_v = tl.make_block_ptr(v + (tl.cdiv(bos, BS) * H + i_h) * V, (V, TC), (1, H*V), (i_v * BV, i_c), (BV, BC), (0, 1))
215+
p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
216+
p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (V, TC), (1, H*V), (i_v * BV, i_c), (BV, BC), (0, 1))
211217
# [BK, BC]
212218
b_k = tl.load(p_k, boundary_check=(0, 1))
213219
# [BV, BC]
@@ -249,6 +255,7 @@ def parallel_nsa_compression_bwd_kernel_dkv(
249255
dv,
250256
offsets,
251257
chunk_indices,
258+
chunk_offsets,
252259
scale,
253260
T,
254261
B: tl.constexpr,
@@ -270,18 +277,18 @@ def parallel_nsa_compression_bwd_kernel_dkv(
270277
i_n, i_c = tl.load(chunk_indices + i_c * 2).to(tl.int32), tl.load(chunk_indices + i_c * 2 + 1).to(tl.int32)
271278
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
272279
T = eos - bos
280+
boc = tl.load(chunk_offsets + i_n).to(tl.int32)
273281
else:
274282
bos, eos = i_b * T, i_b * T + T
283+
boc = i_b * tl.cdiv(T, BS)
275284

276285
# the number of compression representations in total
277286
TC = tl.cdiv(T, BS)
278287

279-
p_k = tl.make_block_ptr(k + (tl.cdiv(bos, BS) * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
280-
p_v = tl.make_block_ptr(v + (tl.cdiv(bos, BS) * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))
281-
p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + tl.cdiv(bos, BS) * H + i_h) * K,
282-
(TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
283-
p_dv = tl.make_block_ptr(dv + (i_v * B*T*H + tl.cdiv(bos, BS) * H + i_h) * V,
284-
(TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))
288+
p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
289+
p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))
290+
p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
291+
p_dv = tl.make_block_ptr(dv + (i_v * B*T*H + boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))
285292

286293
# [BC, BK]
287294
b_k = tl.load(p_k, boundary_check=(0, 1))
@@ -344,6 +351,7 @@ def parallel_nsa_kernel_topk(
344351
block_indices,
345352
offsets,
346353
token_indices,
354+
chunk_offsets,
347355
T,
348356
H: tl.constexpr,
349357
HQ: tl.constexpr,
@@ -363,8 +371,10 @@ def parallel_nsa_kernel_topk(
363371
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
364372
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
365373
T = eos - bos
374+
boc = tl.load(chunk_offsets + i_n).to(tl.int32)
366375
else:
367376
bos, eos = i_b * T, i_b * T + T
377+
boc = i_b * tl.cdiv(T, BS)
368378

369379
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
370380

@@ -382,7 +392,7 @@ def parallel_nsa_kernel_topk(
382392
# 1. lse computation
383393
################################
384394
if lse is not None:
385-
b_lse = tl.load(lse + (tl.cdiv(bos, BS) + i_t) * HQ + i_h * G + tl.arange(0, G))
395+
b_lse = tl.load(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G))
386396
else:
387397
# max scores for the current block
388398
b_m = tl.full([G], float('-inf'), dtype=tl.float32)
@@ -391,7 +401,7 @@ def parallel_nsa_kernel_topk(
391401
for i_c in range(0, NC, BC):
392402
o_c = i_c + tl.arange(0, BC)
393403

394-
p_k = tl.make_block_ptr(k + (tl.cdiv(bos, BS) * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
404+
p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
395405
# [BK, BC]
396406
b_k = tl.load(p_k, boundary_check=(0, 1))
397407

@@ -423,7 +433,7 @@ def parallel_nsa_kernel_topk(
423433
for i_c in range(0, i_t // BS + 1, BC):
424434
o_c = i_c + tl.arange(0, BC)
425435

426-
p_k = tl.make_block_ptr(k + (tl.cdiv(bos, BS) * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
436+
p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
427437
# [BK, BC]
428438
b_k = tl.load(p_k, boundary_check=(0, 1))
429439
# [G, BC]
@@ -842,6 +852,8 @@ def parallel_nsa_compression_fwd(
842852
NV = triton.cdiv(V, BV)
843853
assert NK == 1, "The key dimension can not be larger than 256"
844854

855+
chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None
856+
845857
grid = (T, NV, B * H)
846858
o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
847859
lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
@@ -855,6 +867,7 @@ def parallel_nsa_compression_fwd(
855867
scale=scale,
856868
offsets=offsets,
857869
token_indices=token_indices,
870+
chunk_offsets=chunk_offsets,
858871
T=T,
859872
H=H,
860873
HQ=HQ,
@@ -888,6 +901,15 @@ def parallel_nsa_compression_bwd(
888901
BK = triton.next_power_of_2(K)
889902
BV = min(128, triton.next_power_of_2(v.shape[-1]))
890903
NV = triton.cdiv(V, BV)
904+
if offsets is not None:
905+
lens = prepare_lens(offsets)
906+
chunk_indices = torch.cat([torch.arange(n) for n in triton.cdiv(triton.cdiv(lens, BS), BC).tolist()])
907+
chunk_indices = torch.stack([chunk_indices.eq(0).cumsum(0) - 1, chunk_indices], 1).to(offsets)
908+
chunk_offsets = prepare_chunk_offsets(offsets, BS)
909+
NC = len(chunk_indices)
910+
else:
911+
chunk_indices, chunk_offsets = None, None
912+
NC = triton.cdiv(triton.cdiv(T, BS), BC)
891913

892914
delta = parallel_nsa_bwd_preprocess(o, do)
893915

@@ -904,6 +926,7 @@ def parallel_nsa_compression_bwd(
904926
scale=scale,
905927
offsets=offsets,
906928
token_indices=token_indices,
929+
chunk_offsets=chunk_offsets,
907930
T=T,
908931
B=B,
909932
H=H,
@@ -918,15 +941,6 @@ def parallel_nsa_compression_bwd(
918941
)
919942
dq = dq.sum(0)
920943

921-
if offsets is not None:
922-
lens = prepare_lens(offsets)
923-
chunk_indices = torch.cat([torch.arange(n) for n in triton.cdiv(triton.cdiv(lens, BS), BC).tolist()])
924-
chunk_indices = torch.stack([chunk_indices.eq(0).cumsum(0) - 1, chunk_indices], 1).to(offsets)
925-
NC = len(chunk_indices)
926-
else:
927-
chunk_indices = None
928-
NC = triton.cdiv(triton.cdiv(T, BS), BC)
929-
930944
dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device)
931945
dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)
932946

@@ -942,6 +956,7 @@ def parallel_nsa_compression_bwd(
942956
dv=dv,
943957
offsets=offsets,
944958
chunk_indices=chunk_indices,
959+
chunk_offsets=chunk_offsets,
945960
scale=scale,
946961
T=T,
947962
B=B,
@@ -1037,6 +1052,7 @@ def parallel_nsa_topk(
10371052

10381053
block_indices = torch.zeros(B, T, H, S, dtype=torch.int32, device=q.device)
10391054
token_indices = prepare_token_indices(offsets) if offsets is not None else None
1055+
chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None
10401056
grid = (T, B * H)
10411057
parallel_nsa_kernel_topk[grid](
10421058
q=q,
@@ -1046,6 +1062,7 @@ def parallel_nsa_topk(
10461062
block_indices=block_indices,
10471063
offsets=offsets,
10481064
token_indices=token_indices,
1065+
chunk_offsets=chunk_offsets,
10491066
T=T,
10501067
H=H,
10511068
HQ=HQ,

0 commit comments

Comments (0)