Skip to content

Commit b1d78f5

Browse files
committed
Adjust autotuning args
1 parent 28e1179 commit b1d78f5

File tree

3 files changed

15 additions (+15) and 26 deletions (-26)

native_sparse_attention/ops/parallel.py

Lines changed: 14 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
import torch
88
import triton
99
import triton.language as tl
10-
import triton.language.core as core
1110
from einops import rearrange
1211

1312
from fla.ops.common.utils import (prepare_chunk_indices, prepare_chunk_offsets,
@@ -31,9 +30,8 @@
3130
})
3231
@triton.autotune(
3332
configs=[
34-
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
35-
for num_warps in [1, 2, 4, 8]
36-
for num_stages in [2, 3, 4, 5]
33+
triton.Config({}, num_warps=num_warps)
34+
for num_warps in [1, 2, 4]
3735
],
3836
key=['BS', 'BK', 'BV'],
3937
)
@@ -134,9 +132,8 @@ def parallel_nsa_compression_fwd_kernel(
134132
})
135133
@triton.autotune(
136134
configs=[
137-
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
138-
for num_warps in [1, 2, 4, 8]
139-
for num_stages in [2, 3, 4, 5]
135+
triton.Config({}, num_warps=num_warps)
136+
for num_warps in [1, 2, 4]
140137
],
141138
key=['BS', 'BK', 'BV'],
142139
)
@@ -238,7 +235,7 @@ def parallel_nsa_compression_bwd_kernel_dq(
238235
@triton.autotune(
239236
configs=[
240237
triton.Config({}, num_warps=num_warps)
241-
for num_warps in [1, 2, 4, 8]
238+
for num_warps in [1, 2, 4]
242239
],
243240
key=['BS', 'BK', 'BV'],
244241
)
@@ -334,9 +331,8 @@ def parallel_nsa_compression_bwd_kernel_dkv(
334331
})
335332
@triton.autotune(
336333
configs=[
337-
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
338-
for num_warps in [1, 2, 4, 8]
339-
for num_stages in [2, 3, 4, 5]
334+
triton.Config({}, num_warps=num_warps)
335+
for num_warps in [1, 2, 4]
340336
],
341337
key=['BS', 'BK'],
342338
)
@@ -443,8 +439,8 @@ def parallel_nsa_kernel_topk(
443439
b_i, b_ip = tl.sum(b_p, 0), b_i
444440
o_i, o_ip = tl.where(o_c <= i_t // BS, o_c + 1, 0), o_i
445441

446-
n_dims: core.constexpr = tl.standard._log2(b_i.shape[0])
447-
for i in core.static_range(1, n_dims):
442+
n_dims: tl.constexpr = tl.standard._log2(b_i.shape[0])
443+
for i in tl.static_range(1, n_dims):
448444
b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), i, 2, n_dims)
449445

450446
if i_c != 0:
@@ -469,7 +465,7 @@ def parallel_nsa_kernel_topk(
469465
@triton.autotune(
470466
configs=[
471467
triton.Config({}, num_warps=num_warps)
472-
for num_warps in [1, 2, 4, 8]
468+
for num_warps in [1, 2, 4]
473469
],
474470
key=['BS', 'BK', 'BV'],
475471
)
@@ -613,9 +609,8 @@ def parallel_nsa_bwd_kernel_preprocess(
613609
})
614610
@triton.autotune(
615611
configs=[
616-
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
617-
for num_warps in [1, 2, 4, 8]
618-
for num_stages in [2, 3, 4, 5]
612+
triton.Config({}, num_warps=num_warps)
613+
for num_warps in [1, 2, 4]
619614
],
620615
key=['BS', 'BK', 'BV'],
621616
)
@@ -721,9 +716,8 @@ def parallel_nsa_bwd_kernel_dq(
721716
})
722717
@triton.autotune(
723718
configs=[
724-
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
725-
for num_warps in [1, 2, 4, 8]
726-
for num_stages in [2, 3, 4, 5]
719+
triton.Config({}, num_warps=num_warps)
720+
for num_warps in [1, 2, 4]
727721
],
728722
key=['BS', 'BK', 'BV'],
729723
)

native_sparse_attention/ops/utils.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -55,11 +55,6 @@ def _bitonic_merge(
5555
order: core.constexpr,
5656
n_dims: core.constexpr,
5757
):
58-
'''
59-
order_type 0 == ascending
60-
order_type 1 == descending
61-
order_type 2 == alternating
62-
'''
6358
n_outer: core.constexpr = x.numel >> n_dims
6459
core.static_assert(stage <= n_dims)
6560
# flip denotes whether to re-arrange sub-sequences of elements in ascending or

0 commit comments

Comments
 (0)