Skip to content

Commit 7b099fc

Browse files
committed
Fix reshape bugs
1 parent 4fe828c commit 7b099fc

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

native_sparse_attention/ops/parallel.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -266,8 +266,10 @@ def parallel_nsa_compression_bwd_kernel_dkv(
266266

267267
p_k = tl.make_block_ptr(k + (tl.cdiv(bos, BS) * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
268268
p_v = tl.make_block_ptr(v + (tl.cdiv(bos, BS) * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))
269-
p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + tl.cdiv(bos, BS) * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
270-
p_dv = tl.make_block_ptr(dv + (i_v * B*T*H + tl.cdiv(bos, BS) * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))
269+
p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + tl.cdiv(bos, BS) * H + i_h) * K,
270+
(TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
271+
p_dv = tl.make_block_ptr(dv + (i_v * B*T*H + tl.cdiv(bos, BS) * H + i_h) * V,
272+
(TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))
271273

272274
# [BC, BK]
273275
b_k = tl.load(p_k, boundary_check=(0, 1))
@@ -310,6 +312,7 @@ def parallel_nsa_compression_bwd_kernel_dkv(
310312

311313

312314
@triton.heuristics({
315+
'BC2': lambda args: args['BC'] // 2,
313316
'USE_OFFSETS': lambda args: args['offsets'] is not None
314317
})
315318
@triton.autotune(
@@ -335,6 +338,7 @@ def parallel_nsa_kernel_topk(
335338
K: tl.constexpr,
336339
S: tl.constexpr,
337340
BC: tl.constexpr,
341+
BC2: tl.constexpr,
338342
BS: tl.constexpr,
339343
BK: tl.constexpr,
340344
USE_OFFSETS: tl.constexpr,
@@ -402,7 +406,7 @@ def parallel_nsa_kernel_topk(
402406
# [BC]
403407
b_i = tl.full([BC], -1, dtype=tl.float32)
404408
o_i = tl.zeros([BC], dtype=tl.int32)
405-
m_i = tl.arange(0, BC) < S
409+
m_i = tl.arange(0, BC) < BC2
406410
for i_c in range(0, i_t // BS + 1, BC):
407411
o_c = i_c + tl.arange(0, BC)
408412

@@ -432,8 +436,7 @@ def parallel_nsa_kernel_topk(
432436
b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, True, n_dims)
433437

434438
m_top = tl.arange(0, 2) == 0
435-
436-
b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [2, S]), 0)
439+
b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [2, BC2]), 0)
437440

438441
p_b = tl.make_block_ptr(block_indices + (bos + i_t) * H*S, (H*S,), (1,), (i_h * S,), (S,), (0,))
439442
tl.store(p_b, b_top.to(p_b.dtype.element_ty))
@@ -1369,6 +1372,7 @@ def parallel_nsa_bwd(
13691372
dk = dk.sum(0)
13701373
return dq, dk, dv
13711374

1375+
13721376
@torch.compile
13731377
class ParallelNSAFunction(torch.autograd.Function):
13741378

@@ -1609,7 +1613,6 @@ def parallel_nsa_with_compression(
16091613
offsets=cu_seqlens
16101614
)
16111615
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
1612-
pos = torch.arange(0, q.shape[1]).view(1, q.shape[1], 1).to(q.device).to(block_indices.dtype)
16131616
o = o_slc * g_slc.unsqueeze(-1)
16141617
if o_cmp is not None:
16151618
o = torch.addcmul(o, o_cmp, g_cmp.unsqueeze(-1))

0 commit comments

Comments (0)