# Code adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396

import triton
import triton.language as tl


1513@triton .jit
1614def _compare_and_swap (
1715 x ,
1816 ids ,
1917 flip ,
20- i : core .constexpr ,
21- n_dims : core .constexpr ,
18+ i : tl .constexpr ,
19+ n_dims : tl .constexpr ,
2220):
23- n_outer : core .constexpr = x .numel >> n_dims
24- shape : core .constexpr = [n_outer * 2 ** i , 2 , 2 ** (n_dims - i - 1 )]
25- y = core .reshape (x , shape )
21+ n_outer : tl .constexpr = x .numel >> n_dims
22+ shape : tl .constexpr = [n_outer * 2 ** i , 2 , 2 ** (n_dims - i - 1 )]
23+ y = tl .reshape (x , shape )
2624 # slice left/right with 'stride' 2**(n_dims - i - 1)
27- mask = core .arange (0 , 2 )[None , :, None ]
28- left = core .broadcast_to (sum (y * (1 - mask ), 1 )[:, None , :], shape ).to (y .dtype )
29- right = core .broadcast_to (sum (y * mask , 1 )[:, None , :], shape ).to (y .dtype )
30- left = core .reshape (left , x .shape )
31- right = core .reshape (right , x .shape )
25+ mask = tl .arange (0 , 2 )[None , :, None ]
26+ left = tl .broadcast_to (tl . sum (y * (1 - mask ), 1 )[:, None , :], shape ).to (y .dtype )
27+ right = tl .broadcast_to (tl . sum (y * mask , 1 )[:, None , :], shape ).to (y .dtype )
28+ left = tl .reshape (left , x .shape )
29+ right = tl .reshape (right , x .shape )
3230 # idx
33- y_idx = core .reshape (ids , shape )
34- left_idx = core .broadcast_to (sum (y_idx * (1 - mask ), 1 )[:, None , :], shape )
35- right_idx = core .broadcast_to (sum (y_idx * mask , 1 )[:, None , :], shape )
36- left_idx = core .reshape (left_idx , x .shape ).to (y_idx .dtype )
37- right_idx = core .reshape (right_idx , x .shape ).to (y_idx .dtype )
31+ y_idx = tl .reshape (ids , shape )
32+ left_idx = tl .broadcast_to (tl . sum (y_idx * (1 - mask ), 1 )[:, None , :], shape )
33+ right_idx = tl .broadcast_to (tl . sum (y_idx * mask , 1 )[:, None , :], shape )
34+ left_idx = tl .reshape (left_idx , x .shape ).to (y_idx .dtype )
35+ right_idx = tl .reshape (right_idx , x .shape ).to (y_idx .dtype )
3836 # actual compare-and-swap
39- idtype = core .get_int_dtype (bitwidth = x .dtype .primitive_bitwidth , signed = True )
37+ idtype = tl . core .get_int_dtype (bitwidth = x .dtype .primitive_bitwidth , signed = True )
4038 ileft = left .to (idtype , bitcast = True )
4139 iright = right .to (idtype , bitcast = True )
4240 ix = x .to (idtype , bitcast = True )
4341
4442 cond = (left > right ) != flip
45- ret = ix ^ core .where (cond , ileft ^ iright , zeros_like (ix ))
46- new_ids = ids ^ core .where (cond , left_idx ^ right_idx , zeros_like (ids ))
43+ ret = ix ^ tl .where (cond , ileft ^ iright , tl . zeros_like (ix ))
44+ new_ids = ids ^ tl .where (cond , left_idx ^ right_idx , tl . zeros_like (ids ))
4745 return ret .to (x .dtype , bitcast = True ), new_ids
4846
4947
5048@triton .jit
5149def _bitonic_merge (
5250 x ,
5351 ids ,
54- stage : core .constexpr ,
55- order : core .constexpr ,
56- n_dims : core .constexpr ,
52+ stage : tl .constexpr ,
53+ order : tl .constexpr ,
54+ n_dims : tl .constexpr ,
5755):
58- '''
59- order_type 0 == ascending
60- order_type 1 == descending
61- order_type 2 == alternating
62- '''
63- n_outer : core .constexpr = x .numel >> n_dims
64- core .static_assert (stage <= n_dims )
56+ n_outer : tl .constexpr = x .numel >> n_dims
57+ tl .static_assert (stage <= n_dims )
6558 # flip denotes whether to re-arrange sub-sequences of elements in ascending or
6659 # descending order.
6760 # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
6861 # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
6962 # a stride of 2) at this stage
7063 if order == 2 :
71- shape : core .constexpr = [n_outer * 2 ** (n_dims - 1 - stage ), 2 , 2 ** stage ]
72- flip = core .reshape (core .broadcast_to (core .arange (0 , 2 )[None , :, None ], shape ), x .shape )
64+ shape : tl .constexpr = [n_outer * 2 ** (n_dims - 1 - stage ), 2 , 2 ** stage ]
65+ flip = tl .reshape (tl .broadcast_to (tl .arange (0 , 2 )[None , :, None ], shape ), x .shape )
7366 else :
7467 flip = order
7568 # perform `stage` rounds of `compare-and-swap`
76- for i in core .static_range (stage ):
69+ for i in tl .static_range (stage ):
7770 x , ids = _compare_and_swap (x , ids , flip , i + (n_dims - stage ), n_dims )
7871 return x , ids
7972
@@ -82,15 +75,15 @@ def _bitonic_merge(
8275def argsort (
8376 x ,
8477 ids ,
85- dim : core .constexpr = None ,
86- descending : core .constexpr = core .CONSTEXPR_0 ,
78+ dim : tl .constexpr = None ,
79+ descending : tl .constexpr = tl . core .CONSTEXPR_0 ,
8780):
8881 # handle default dimension or check that it is the most minor dim
89- _dim : core .constexpr = len (x .shape ) - 1 if dim is None else dim
90- core .static_assert (_dim == len (x .shape ) - 1 , "only minor dimension is currently supported" )
82+ _dim : tl .constexpr = len (x .shape ) - 1 if dim is None else dim
83+ tl .static_assert (_dim == len (x .shape ) - 1 , "only minor dimension is currently supported" )
9184 # iteratively run bitonic merge-sort steps
92- n_dims : core .constexpr = _log2 (x .shape [_dim ])
85+ n_dims : tl .constexpr = tl . log2 (x .shape [_dim ])
9386
94- for i in core .static_range (1 , n_dims + 1 ):
87+ for i in tl .static_range (1 , n_dims + 1 ):
9588 x , ids = _bitonic_merge (x , ids , i , 2 if i < n_dims else descending , n_dims )
9689 return x , ids