77import torch
88import triton
99import triton .language as tl
10- import triton .language .core as core
1110from einops import rearrange
1211
1312from fla .ops .common .utils import (prepare_chunk_indices , prepare_chunk_offsets ,
3130})
3231@triton .autotune (
3332 configs = [
34- triton .Config ({}, num_warps = num_warps , num_stages = num_stages )
35- for num_warps in [1 , 2 , 4 , 8 ]
36- for num_stages in [2 , 3 , 4 , 5 ]
33+ triton .Config ({}, num_warps = num_warps )
34+ for num_warps in [1 , 2 , 4 ]
3735 ],
3836 key = ['BS' , 'BK' , 'BV' ],
3937)
@@ -134,9 +132,8 @@ def parallel_nsa_compression_fwd_kernel(
134132})
135133@triton .autotune (
136134 configs = [
137- triton .Config ({}, num_warps = num_warps , num_stages = num_stages )
138- for num_warps in [1 , 2 , 4 , 8 ]
139- for num_stages in [2 , 3 , 4 , 5 ]
135+ triton .Config ({}, num_warps = num_warps )
136+ for num_warps in [1 , 2 , 4 ]
140137 ],
141138 key = ['BS' , 'BK' , 'BV' ],
142139)
@@ -238,7 +235,7 @@ def parallel_nsa_compression_bwd_kernel_dq(
238235@triton .autotune (
239236 configs = [
240237 triton .Config ({}, num_warps = num_warps )
241- for num_warps in [1 , 2 , 4 , 8 ]
238+ for num_warps in [1 , 2 , 4 ]
242239 ],
243240 key = ['BS' , 'BK' , 'BV' ],
244241)
@@ -334,9 +331,8 @@ def parallel_nsa_compression_bwd_kernel_dkv(
334331})
335332@triton .autotune (
336333 configs = [
337- triton .Config ({}, num_warps = num_warps , num_stages = num_stages )
338- for num_warps in [1 , 2 , 4 , 8 ]
339- for num_stages in [2 , 3 , 4 , 5 ]
334+ triton .Config ({}, num_warps = num_warps )
335+ for num_warps in [1 , 2 , 4 ]
340336 ],
341337 key = ['BS' , 'BK' ],
342338)
@@ -443,8 +439,8 @@ def parallel_nsa_kernel_topk(
443439 b_i , b_ip = tl .sum (b_p , 0 ), b_i
444440 o_i , o_ip = tl .where (o_c <= i_t // BS , o_c + 1 , 0 ), o_i
445441
446- n_dims : core .constexpr = tl .standard ._log2 (b_i .shape [0 ])
447- for i in core .static_range (1 , n_dims ):
442+ n_dims : tl .constexpr = tl .standard ._log2 (b_i .shape [0 ])
443+ for i in tl .static_range (1 , n_dims ):
448444 b_i , o_i = _bitonic_merge (b_i , o_i .to (tl .int32 ), i , 2 , n_dims )
449445
450446 if i_c != 0 :
@@ -469,7 +465,7 @@ def parallel_nsa_kernel_topk(
469465@triton .autotune (
470466 configs = [
471467 triton .Config ({}, num_warps = num_warps )
472- for num_warps in [1 , 2 , 4 , 8 ]
468+ for num_warps in [1 , 2 , 4 ]
473469 ],
474470 key = ['BS' , 'BK' , 'BV' ],
475471)
@@ -613,9 +609,8 @@ def parallel_nsa_bwd_kernel_preprocess(
613609})
614610@triton .autotune (
615611 configs = [
616- triton .Config ({}, num_warps = num_warps , num_stages = num_stages )
617- for num_warps in [1 , 2 , 4 , 8 ]
618- for num_stages in [2 , 3 , 4 , 5 ]
612+ triton .Config ({}, num_warps = num_warps )
613+ for num_warps in [1 , 2 , 4 ]
619614 ],
620615 key = ['BS' , 'BK' , 'BV' ],
621616)
@@ -721,9 +716,8 @@ def parallel_nsa_bwd_kernel_dq(
721716})
722717@triton .autotune (
723718 configs = [
724- triton .Config ({}, num_warps = num_warps , num_stages = num_stages )
725- for num_warps in [1 , 2 , 4 , 8 ]
726- for num_stages in [2 , 3 , 4 , 5 ]
719+ triton .Config ({}, num_warps = num_warps )
720+ for num_warps in [1 , 2 , 4 ]
727721 ],
728722 key = ['BS' , 'BK' , 'BV' ],
729723)
0 commit comments