Skip to content

Commit f56759d

Browse files
committed
Autotune args for better performance
1 parent 0125dfe commit f56759d

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

native_sparse_attention/ops/parallel.py

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,9 @@
32 32
})
33 33
@triton.autotune(
34 34
configs=[
35-
triton.Config({}, num_warps=num_warps)
36-
for num_warps in [1, 2, 4]
35+
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
36+
for num_warps in [1, 2, 4, 8]
37+
for num_stages in [2, 3, 4, 5]
37 38
],
38 39
key=['BS', 'BK', 'BV'],
39 40
)
@@ -131,8 +132,9 @@ def parallel_nsa_compression_fwd_kernel(
131 132
})
132 133
@triton.autotune(
133 134
configs=[
134-
triton.Config({}, num_warps=num_warps)
135+
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
135 136
for num_warps in [1, 2, 4, 8]
137+
for num_stages in [2, 3, 4, 5]
136 138
],
137 139
key=['BS', 'BK', 'BV'],
138 140
)
@@ -327,8 +329,9 @@ def parallel_nsa_compression_bwd_kernel_dkv(
327 329
})
328 330
@triton.autotune(
329 331
configs=[
330-
triton.Config({}, num_warps=num_warps)
331-
for num_warps in [1, 2, 4]
332+
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
333+
for num_warps in [1, 2, 4, 8]
334+
for num_stages in [2, 3, 4, 5]
332 335
],
333 336
key=['BS', 'BK'],
334 337
)
@@ -603,8 +606,9 @@ def parallel_nsa_bwd_kernel_preprocess(
603 606
})
604 607
@triton.autotune(
605 608
configs=[
606-
triton.Config({}, num_warps=num_warps)
609+
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
607 610
for num_warps in [1, 2, 4, 8]
611+
for num_stages in [2, 3, 4, 5]
608 612
],
609 613
key=['BS', 'BK', 'BV'],
610 614
)
@@ -710,8 +714,9 @@ def parallel_nsa_bwd_kernel_dq(
710 714
})
711 715
@triton.autotune(
712 716
configs=[
713-
triton.Config({}, num_warps=num_warps)
717+
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
714 718
for num_warps in [1, 2, 4, 8]
719+
for num_stages in [2, 3, 4, 5]
715 720
],
716 721
key=['BS', 'BK', 'BV'],
717 722
)
@@ -1030,7 +1035,7 @@ def parallel_nsa_topk(
1030 1035
BC = BS = block_size
1031 1036
BK = triton.next_power_of_2(K)
1032 1037

1033-
block_indices = torch.zeros(B, T, H, S, dtype=torch.long, device=q.device)
1038+
block_indices = torch.zeros(B, T, H, S, dtype=torch.int32, device=q.device)
1034 1039
token_indices = prepare_token_indices(offsets) if offsets is not None else None
1035 1040
grid = (T, B * H)
1036 1041
parallel_nsa_kernel_topk[grid](

0 commit comments

Comments (0)