Skip to content

Commit 28e1179

Browse files
committed
Fix shape mismatch when saving block indices
1 parent a60d291 · commit 28e1179

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

native_sparse_attention/ops/parallel.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,6 @@ def parallel_nsa_compression_bwd_kernel_dkv(
330330

331331

332332
@triton.heuristics({
333-
'BC2': lambda args: args['BC'] // 2,
334333
'USE_OFFSETS': lambda args: args['offsets'] is not None
335334
})
336335
@triton.autotune(
@@ -358,7 +357,6 @@ def parallel_nsa_kernel_topk(
358357
K: tl.constexpr,
359358
S: tl.constexpr,
360359
BC: tl.constexpr,
361-
BC2: tl.constexpr,
362360
BS: tl.constexpr,
363361
BK: tl.constexpr,
364362
USE_OFFSETS: tl.constexpr,
@@ -428,7 +426,7 @@ def parallel_nsa_kernel_topk(
428426
# [BC]
429427
b_i = tl.full([BC], -1, dtype=tl.float32)
430428
o_i = tl.zeros([BC], dtype=tl.int32)
431-
m_i = tl.arange(0, BC) < BC2
429+
m_i = tl.arange(0, BC) < BC//2
432430
for i_c in range(0, i_t // BS + 1, BC):
433431
o_c = i_c + tl.arange(0, BC)
434432

@@ -457,8 +455,8 @@ def parallel_nsa_kernel_topk(
457455
else:
458456
b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, True, n_dims)
459457

460-
m_top = tl.arange(0, 2) == 0
461-
b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [2, BC2]), 0)
458+
m_top = tl.arange(0, BC//S) == 0
459+
b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [BC//S, S]), 0)
462460

463461
p_b = tl.make_block_ptr(block_indices + (bos + i_t) * H*S, (H*S,), (1,), (i_h * S,), (S,), (0,))
464462
tl.store(p_b, b_top.to(p_b.dtype.element_ty))

0 commit comments

Comments
 (0)