@@ -346,6 +346,12 @@ def partition(
346
346
The ordering of the elements in the two partitions on either side of
347
347
the k-th element in the output array is undefined.
348
348
349
+ Notes:
350
+ If `xp` implements `partition` or an equivalent method (e.g. kthvalue for torch),
351
+ complexity will likely be O(n).
352
+ If not, this function simply calls `xp.sort` and complexity is O(n log n).
353
+
354
+ Parameters
349
355
----------
350
356
a : Array
351
357
Input array.
@@ -361,11 +367,6 @@ def partition(
361
367
-------
362
368
partitioned_array
363
369
Array of the same type and shape as `a`.
364
-
365
- Notes:
366
- If `xp` implements `partition` or an equivalent method (e.g. topk for torch),
367
- complexity will likely be O(n).
368
- If not, this function simply calls `xp.sort` and complexity is O(n log n).
369
370
"""
370
371
# Validate inputs.
371
372
if xp is None :
@@ -389,26 +390,32 @@ def partition(
389
390
if not (axis == - 1 or axis == a .ndim - 1 ):
390
391
a = xp .transpose (a , axis , - 1 )
391
392
392
- # Get smallest `kth` elements along axis
393
- kth += 1 # HACK: we use a non-specified behavior of torch.topk:
394
- # in `a_left`, the element in the last position is the max
395
- a_left , indices = xp .topk (a , kth , dim = - 1 , largest = False , sorted = False )
393
+ out = xp .empty_like (a )
394
+ ranks = xp .arange (a .shape [- 1 ]).expand_as (a )
395
+
396
+ split_value , indices = xp .kthvalue (a , kth + 1 , keepdim = True )
397
+ del indices
396
398
397
- # Build a mask to remove the selected elements
398
- mask_right = xp .ones (a .shape , dtype = bool )
399
- mask_right .scatter_ (dim = - 1 , index = indices , value = False )
399
+ # fill the left-side of the partition
400
+ mask_src = a < split_value
401
+ n_left = mask_src .sum (dim = - 1 , keepdim = True )
402
+ mask_dest = ranks < n_left
403
+ out [mask_dest ] = a [mask_src ]
400
404
401
- # Remaining elements along axis
402
- a_right = a [mask_right ] # 1-d array
405
+ # fill the middle of the partition
406
+ mask_src = a == split_value
407
+ n_left += mask_src .sum (dim = - 1 , keepdim = True )
408
+ mask_dest ^= ranks < n_left
409
+ out [mask_dest ] = a [mask_src ]
403
410
404
- # Reshape. This is valid only because we work on the last axis
405
- a_right = xp .reshape (a_right , shape = (* a .shape [:- 1 ], - 1 ))
411
+ # fill the right-side of the partition
412
+ mask_src = a > split_value
413
+ mask_dest = ranks >= n_left
414
+ out [mask_dest ] = a [mask_src ]
406
415
407
- # Concatenate the two parts along axis
408
- partitioned_array = xp .cat ((a_left , a_right ), dim = - 1 )
409
416
if not (axis == - 1 or axis == a .ndim - 1 ):
410
- partitioned_array = xp .transpose (partitioned_array , axis , - 1 )
411
- return partitioned_array
417
+ out = xp .transpose (out , axis , - 1 )
418
+ return out
412
419
413
420
# Note: dask topk/argtopk sort the return values, so it's
414
421
# not much more efficient than sorting everything when
@@ -427,9 +434,15 @@ def argpartition(
427
434
) -> Array :
428
435
"""
429
436
Perform an indirect partition along the given axis.
437
+
430
438
It returns an array of indices of the same shape as `a` that
431
439
index data along the given axis in partitioned order.
432
440
441
+ Notes:
442
+ If `xp` implements `argpartition` or an equivalent method (e.g. kthvalue for torch),
443
+ complexity will likely be O(n).
444
+ If not, this function simply calls `xp.argsort` and complexity is O(n log n).
445
+
433
446
Parameters
434
447
----------
435
448
a : Array
@@ -446,11 +459,6 @@ def argpartition(
446
459
-------
447
460
index_array
448
461
Array of indices that partition `a` along the specified axis.
449
-
450
- Notes:
451
- If `xp` implements `argpartition` or an equivalent method (e.g. topk for torch),
452
- complexity will likely be O(n).
453
- If not, this function simply calls `xp.argsort` and complexity is O(n log n).
454
462
"""
455
463
# Validate inputs.
456
464
if xp is None :
@@ -478,20 +486,29 @@ def argpartition(
478
486
if not (axis == - 1 or axis == a .ndim - 1 ):
479
487
a = xp .transpose (a , axis , - 1 )
480
488
481
- kth += 1 # HACK
482
- _ , indices_left = xp .topk (a , kth , dim = - 1 , largest = False , sorted = False )
489
+ ranks = xp .arange (a .shape [- 1 ]).expand_as (a )
490
+ out = xp .empty_like (ranks )
491
+
492
+ split_value , indices = xp .kthvalue (a , kth + 1 , keepdim = True )
493
+ del indices
494
+
495
+ mask_src = a < split_value
496
+ n_left = mask_src .sum (dim = - 1 , keepdim = True )
497
+ mask_dest = ranks < n_left
498
+ out [mask_dest ] = ranks [mask_src ]
483
499
484
- mask_right = xp .ones (a .shape , dtype = bool )
485
- mask_right .scatter_ (dim = - 1 , index = indices_left , value = False )
500
+ mask_src = a == split_value
501
+ n_left += mask_src .sum (dim = - 1 , keepdim = True )
502
+ mask_dest ^= ranks < n_left
503
+ out [mask_dest ] = ranks [mask_src ]
486
504
487
- indices_right = xp .nonzero (mask_right )[- 1 ]
488
- indices_right = xp .reshape (indices_right , shape = (* a .shape [:- 1 ], - 1 ))
505
+ mask_src = a > split_value
506
+ mask_dest = ranks >= n_left
507
+ out [mask_dest ] = ranks [mask_src ]
489
508
490
- # Concatenate the two parts along axis
491
- index_array = xp .cat ((indices_left , indices_right ), dim = - 1 )
492
509
if not (axis == - 1 or axis == a .ndim - 1 ):
493
- index_array = xp .transpose (index_array , axis , - 1 )
494
- return index_array
510
+ out = xp .transpose (out , axis , - 1 )
511
+ return out
495
512
496
513
# Note: dask topk/argtopk sort the return values, so it's
497
514
# not much more efficient than sorting everything when
0 commit comments