Skip to content

Commit d146fa6

Browse files
committed
Adjust autotuning args
1 parent 28e1179 commit d146fa6

File tree

3 files changed: +48 −61 lines

native_sparse_attention/ops/parallel.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import torch
88
import triton
99
import triton.language as tl
10-
import triton.language.core as core
1110
from einops import rearrange
1211

1312
from fla.ops.common.utils import (prepare_chunk_indices, prepare_chunk_offsets,
@@ -31,9 +30,8 @@
3130
})
3231
@triton.autotune(
3332
configs=[
34-
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
35-
for num_warps in [1, 2, 4, 8]
36-
for num_stages in [2, 3, 4, 5]
33+
triton.Config({}, num_warps=num_warps)
34+
for num_warps in [1, 2, 4]
3735
],
3836
key=['BS', 'BK', 'BV'],
3937
)
@@ -134,9 +132,8 @@ def parallel_nsa_compression_fwd_kernel(
134132
})
135133
@triton.autotune(
136134
configs=[
137-
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
138-
for num_warps in [1, 2, 4, 8]
139-
for num_stages in [2, 3, 4, 5]
135+
triton.Config({}, num_warps=num_warps)
136+
for num_warps in [1, 2, 4]
140137
],
141138
key=['BS', 'BK', 'BV'],
142139
)
@@ -238,7 +235,7 @@ def parallel_nsa_compression_bwd_kernel_dq(
238235
@triton.autotune(
239236
configs=[
240237
triton.Config({}, num_warps=num_warps)
241-
for num_warps in [1, 2, 4, 8]
238+
for num_warps in [1, 2, 4]
242239
],
243240
key=['BS', 'BK', 'BV'],
244241
)
@@ -334,9 +331,8 @@ def parallel_nsa_compression_bwd_kernel_dkv(
334331
})
335332
@triton.autotune(
336333
configs=[
337-
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
338-
for num_warps in [1, 2, 4, 8]
339-
for num_stages in [2, 3, 4, 5]
334+
triton.Config({}, num_warps=num_warps)
335+
for num_warps in [1, 2, 4]
340336
],
341337
key=['BS', 'BK'],
342338
)
@@ -443,8 +439,8 @@ def parallel_nsa_kernel_topk(
443439
b_i, b_ip = tl.sum(b_p, 0), b_i
444440
o_i, o_ip = tl.where(o_c <= i_t // BS, o_c + 1, 0), o_i
445441

446-
n_dims: core.constexpr = tl.standard._log2(b_i.shape[0])
447-
for i in core.static_range(1, n_dims):
442+
n_dims: tl.constexpr = tl.standard._log2(b_i.shape[0])
443+
for i in tl.static_range(1, n_dims):
448444
b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), i, 2, n_dims)
449445

450446
if i_c != 0:
@@ -469,7 +465,7 @@ def parallel_nsa_kernel_topk(
469465
@triton.autotune(
470466
configs=[
471467
triton.Config({}, num_warps=num_warps)
472-
for num_warps in [1, 2, 4, 8]
468+
for num_warps in [1, 2, 4]
473469
],
474470
key=['BS', 'BK', 'BV'],
475471
)
@@ -613,9 +609,8 @@ def parallel_nsa_bwd_kernel_preprocess(
613609
})
614610
@triton.autotune(
615611
configs=[
616-
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
617-
for num_warps in [1, 2, 4, 8]
618-
for num_stages in [2, 3, 4, 5]
612+
triton.Config({}, num_warps=num_warps)
613+
for num_warps in [1, 2, 4]
619614
],
620615
key=['BS', 'BK', 'BV'],
621616
)
@@ -721,9 +716,8 @@ def parallel_nsa_bwd_kernel_dq(
721716
})
722717
@triton.autotune(
723718
configs=[
724-
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
725-
for num_warps in [1, 2, 4, 8]
726-
for num_stages in [2, 3, 4, 5]
719+
triton.Config({}, num_warps=num_warps)
720+
for num_warps in [1, 2, 4]
727721
],
728722
key=['BS', 'BK', 'BV'],
729723
)

native_sparse_attention/ops/utils.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,74 +6,67 @@
66

77
# Code adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396
88

9-
109
import triton
11-
import triton.language.core as core
12-
from triton.language.standard import _log2, sum, zeros_like
10+
import triton.language as tl
1311

1412

1513
@triton.jit
1614
def _compare_and_swap(
1715
x,
1816
ids,
1917
flip,
20-
i: core.constexpr,
21-
n_dims: core.constexpr,
18+
i: tl.constexpr,
19+
n_dims: tl.constexpr,
2220
):
23-
n_outer: core.constexpr = x.numel >> n_dims
24-
shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
25-
y = core.reshape(x, shape)
21+
n_outer: tl.constexpr = x.numel >> n_dims
22+
shape: tl.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
23+
y = tl.reshape(x, shape)
2624
# slice left/right with 'stride' 2**(n_dims - i - 1)
27-
mask = core.arange(0, 2)[None, :, None]
28-
left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
29-
right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
30-
left = core.reshape(left, x.shape)
31-
right = core.reshape(right, x.shape)
25+
mask = tl.arange(0, 2)[None, :, None]
26+
left = tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
27+
right = tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
28+
left = tl.reshape(left, x.shape)
29+
right = tl.reshape(right, x.shape)
3230
# idx
33-
y_idx = core.reshape(ids, shape)
34-
left_idx = core.broadcast_to(sum(y_idx * (1 - mask), 1)[:, None, :], shape)
35-
right_idx = core.broadcast_to(sum(y_idx * mask, 1)[:, None, :], shape)
36-
left_idx = core.reshape(left_idx, x.shape).to(y_idx.dtype)
37-
right_idx = core.reshape(right_idx, x.shape).to(y_idx.dtype)
31+
y_idx = tl.reshape(ids, shape)
32+
left_idx = tl.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape)
33+
right_idx = tl.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape)
34+
left_idx = tl.reshape(left_idx, x.shape).to(y_idx.dtype)
35+
right_idx = tl.reshape(right_idx, x.shape).to(y_idx.dtype)
3836
# actual compare-and-swap
39-
idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
37+
idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
4038
ileft = left.to(idtype, bitcast=True)
4139
iright = right.to(idtype, bitcast=True)
4240
ix = x.to(idtype, bitcast=True)
4341

4442
cond = (left > right) != flip
45-
ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))
46-
new_ids = ids ^ core.where(cond, left_idx ^ right_idx, zeros_like(ids))
43+
ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
44+
new_ids = ids ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(ids))
4745
return ret.to(x.dtype, bitcast=True), new_ids
4846

4947

5048
@triton.jit
5149
def _bitonic_merge(
5250
x,
5351
ids,
54-
stage: core.constexpr,
55-
order: core.constexpr,
56-
n_dims: core.constexpr,
52+
stage: tl.constexpr,
53+
order: tl.constexpr,
54+
n_dims: tl.constexpr,
5755
):
58-
'''
59-
order_type 0 == ascending
60-
order_type 1 == descending
61-
order_type 2 == alternating
62-
'''
63-
n_outer: core.constexpr = x.numel >> n_dims
64-
core.static_assert(stage <= n_dims)
56+
n_outer: tl.constexpr = x.numel >> n_dims
57+
tl.static_assert(stage <= n_dims)
6558
# flip denotes whether to re-arrange sub-sequences of elements in ascending or
6659
# descending order.
6760
# if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
6861
# if flip = 00110011... then all the elements will be re-arranged alternatingly (with
6962
# a stride of 2) at this stage
7063
if order == 2:
71-
shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
72-
flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
64+
shape: tl.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
65+
flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
7366
else:
7467
flip = order
7568
# perform `stage` rounds of `compare-and-swap`
76-
for i in core.static_range(stage):
69+
for i in tl.static_range(stage):
7770
x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
7871
return x, ids
7972

@@ -82,15 +75,15 @@ def _bitonic_merge(
8275
def argsort(
8376
x,
8477
ids,
85-
dim: core.constexpr = None,
86-
descending: core.constexpr = core.CONSTEXPR_0,
78+
dim: tl.constexpr = None,
79+
descending: tl.constexpr = tl.core.CONSTEXPR_0,
8780
):
8881
# handle default dimension or check that it is the most minor dim
89-
_dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
90-
core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
82+
_dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
83+
tl.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
9184
# iteratively run bitonic merge-sort steps
92-
n_dims: core.constexpr = _log2(x.shape[_dim])
85+
n_dims: tl.constexpr = tl.log2(x.shape[_dim])  # NOTE(review): `tl.log2` is Triton's floating-point math op, not the constexpr integer log2 helper; the parallel.py hunk in this same commit uses `tl.standard._log2` — this line presumably should too. Verify before merging.
9386

94-
for i in core.static_range(1, n_dims + 1):
87+
for i in tl.static_range(1, n_dims + 1):
9588
x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
9689
return x, ids

0 commit comments

Comments (0)