99
1010def _log2 (i : core .constexpr ):
1111 log2 = 0
12- n = core . constexpr ( i ) .value
12+ n = i .value
1313 while n > 1 :
1414 n >>= 1
1515 log2 += 1
@@ -338,19 +338,20 @@ def cumprod(input, axis=0, reverse=False):
def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
    """
    Perform one compare-and-swap step on paired elements of ``x``.

    ``x`` is viewed as ``[n_outer * 2**i, 2, 2**(n_dims - i - 1)]``; the two
    entries along the middle axis form a (left, right) pair.  Wherever
    ``(left > right) != flip`` holds, each element is replaced by its partner;
    everywhere else ``x`` is left untouched.

    :param x: The tensor to process.
    :param flip: Compared against ``left > right`` to decide where a swap
        happens (a scalar or a tensor shaped like ``x``, as built by the caller).
    :param i: Compile-time step index; fixes the distance between the two
        elements of a pair.
    :param n_dims: Compile-time number of trailing binary dimensions of ``x``
        that take part in the step.
    :return: A tensor with the shape and dtype of ``x`` after the conditional swap.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
    y = core.reshape(x, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    # NOTE(review): `sum` and `zeros_like` are called with tensor/axis arguments,
    # so they are presumably this module's own reductions/helpers rather than
    # the Python builtins -- confirm against the rest of the file.
    mask = core.arange(0, 2)[None, :, None]
    # Multiplying by (1 - mask) / mask zeroes one entry of each pair, so the
    # reduction over the middle axis returns the other entry; it is then
    # broadcast back so both positions of a pair see the same left/right value.
    left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
    right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)
    # actual compare-and-swap
    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)
    # Every element equals either `left` or `right`, so XOR-ing its bits with
    # (ileft ^ iright) yields the partner's bits; XOR-ing with zero keeps it.
    ret = ix ^ core.where((left > right) != flip, ileft ^ iright, zeros_like(ix))
    return ret.to(x.dtype, bitcast=True)
354355
355356
356357@jit
@@ -361,14 +362,14 @@ def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core
361362 order_type 2 == alternating
362363 '''
363364 n_outer : core .constexpr = x .numel >> n_dims
365+ core .static_assert (stage <= n_dims )
364366 # flip denotes whether to re-arrange sub-sequences of elements in ascending or
365367 # descending order.
366368 # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
367369 # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
368370 # a stride of 2) at this stage
369371 if order == 2 :
370- core .static_assert (stage <= (n_dims ))
371- shape : core .constexpr = [n_outer * 2 ** (n_dims - 1 - stage ), 2 , 2 ** (stage )]
372+ shape : core .constexpr = [n_outer * 2 ** (n_dims - 1 - stage ), 2 , 2 ** stage ]
372373 flip = core .reshape (core .broadcast_to (core .arange (0 , 2 )[None , :, None ], shape ), x .shape )
373374 else :
374375 flip = order
@@ -378,47 +379,30 @@ def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core
378379 return x
379380
380381
@core._tensor_member_fn
@jit
def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
    """
    Sorts a tensor along a specified dimension.

    :param x: The input tensor to be sorted.
    :type x: Tensor
    :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
    :type dim: int, optional
    :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
    :type descending: bool, optional
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
    # iteratively run bitonic merge-sort steps
    # NOTE(review): the number of stages comes from _log2 of the sorted
    # dimension's size, so a power-of-two length is presumably expected -- confirm.
    n_dims: core.constexpr = _log2(x.shape[_dim])
    for i in core.static_range(1, n_dims + 1):
        # Every stage but the last uses order 2 (alternating); the final stage
        # applies the requested `descending` order.
        x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims)
    return x
412403
413404
414- @jit
415- def sort (x , dim : core .constexpr = None , descending : core .constexpr = core .CONSTEXPR_0 ):
416- return sort_impl (x , dim = dim , descending = descending )
417-
418-
419- @jit
420- def topk (x , k : core .constexpr , dim : core .constexpr = None ):
421- return sort_impl (x , k = k , dim = dim , descending = True )
405+ # flip
422406
423407
424408def _get_flip_dim (dim , shape ):
@@ -450,8 +434,14 @@ def flip(x, dim=None):
450434
451435 idtype = core .get_int_dtype (bitwidth = x .dtype .primitive_bitwidth , signed = True )
452436 y = core .reshape (x .to (idtype , bitcast = True ), [2 ] * steps )
437+ y = core .expand_dims (y , start )
438+ flip = (core .arange (0 , 2 )[:, None ] == 1 - core .arange (0 , 2 ))
453439 for i in core .static_range (start , steps ):
454- y = y ^ xor_sum (y , i , True )
440+ flip2 = flip
441+ for j in core .static_range (0 , steps + 1 ):
442+ if j != i and j != i + 1 :
443+ flip2 = core .expand_dims (flip2 , j )
444+ y = sum (y * flip2 , i + 1 , keep_dims = True , dtype = y .dtype )
455445 x = core .reshape (y , x .shape ).to (x .dtype , bitcast = True )
456446 return x
457447
0 commit comments