@@ -331,6 +331,8 @@ def pad(
331
331
def partition (
332
332
a : Array ,
333
333
kth : int ,
334
+ / ,
335
+ axis : int | None = - 1 ,
334
336
* ,
335
337
xp : ModuleType | None = None ,
336
338
) -> Array :
@@ -343,6 +345,9 @@ def partition(
343
345
Input array.
344
346
kth : int
345
347
Element index to partition by.
348
+ axis : int, optional
349
+ Axis along which to partition. The default is -1 (the last axis).
350
+ If None, the flattened array is used.
346
351
xp : array_namespace, optional
347
352
The standard-compatible namespace for `a`. Default: infer.
348
353
@@ -354,36 +359,61 @@ def partition(
354
359
# Validate inputs.
355
360
if xp is None :
356
361
xp = array_namespace (a )
357
- if a .ndim != 1 :
358
- msg = "only 1-dimensional arrays are currently supported"
359
- raise NotImplementedError (msg )
362
+ if a .ndim < 1 :
363
+ msg = "`a` must be at least 1-dimensional"
364
+ raise TypeError (msg )
365
+ if axis is None :
366
+ return partition (xp .reshape (a , - 1 ), kth , axis = 0 , xp = xp )
367
+ size = a .shape [axis ]
368
+ if size is None :
369
+ msg = "Array dimensions must be known"
370
+ raise ValueError (msg )
371
+ if not (0 <= kth < size ):
372
+ msg = f"kth(={ kth } ) out of bounds [0, { size } )"
373
+ raise ValueError (msg )
360
374
361
375
# Delegate where possible.
362
- if is_numpy_namespace (xp ) or is_cupy_namespace (xp ):
363
- return xp .partition (a , kth )
364
- if is_jax_namespace (xp ):
365
- from jax import numpy
366
-
367
- return numpy .partition (a , kth )
376
+ if is_numpy_namespace (xp ) or is_cupy_namespace (xp ) or is_jax_namespace (xp ):
377
+ return xp .partition (a , kth , axis = axis )
368
378
369
379
# Use top-k when possible:
370
380
if is_torch_namespace (xp ):
371
- from torch import topk
381
+ if not (axis == - 1 or axis == a .ndim - 1 ):
382
+ a = xp .transpose (a , axis , - 1 )
372
383
373
- a_left , indices_left = topk (a , kth , largest = False , sorted = False )
384
+ # Get smallest `kth` elements along axis
385
+ kth += 1 # HACK: we use a non-specified behavior of torch.topk:
386
+ # in `a_left`, the element in the last position is the max
387
+ a_left , indices = xp .topk (a , kth , dim = - 1 , largest = False , sorted = False )
388
+
389
+ # Build a mask to remove the selected elements
374
390
mask_right = xp .ones (a .shape , dtype = bool )
375
- mask_right [indices_left ] = False
376
- return xp .concat ((a_left , a [mask_right ]))
391
+ mask_right .scatter_ (dim = - 1 , index = indices , value = False )
392
+
393
+ # Remaining elements along axis
394
+ a_right = a [mask_right ] # 1-d array
395
+
396
+ # Reshape. This is valid only because we work on the last axis
397
+ a_right = xp .reshape (a_right , shape = (* a .shape [:- 1 ], - 1 ))
398
+
399
+ # Concatenate the two parts along axis
400
+ partitioned_array = xp .cat ((a_left , a_right ), dim = - 1 )
401
+ if not (axis == - 1 or axis == a .ndim - 1 ):
402
+ partitioned_array = xp .transpose (partitioned_array , axis , - 1 )
403
+ return partitioned_array
404
+
377
405
# Note: dask topk/argtopk sort the return values, so it's
378
406
# not much more efficient than sorting everything when
379
407
# kth is not small compared to x.size
380
408
381
- return _funcs .partition (a , kth , xp = xp )
409
+ return _funcs .partition (a , kth , axis = axis , xp = xp )
382
410
383
411
384
412
def argpartition (
385
413
a : Array ,
386
414
kth : int ,
415
+ / ,
416
+ axis : int | None = - 1 ,
387
417
* ,
388
418
xp : ModuleType | None = None ,
389
419
) -> Array :
@@ -392,10 +422,13 @@ def argpartition(
392
422
393
423
Parameters
394
424
----------
395
- a : 1-dimensional array
425
+ a : Array
396
426
Input array.
397
427
kth : int
398
428
Element index to partition by.
429
+ axis : int, optional
430
+ Axis along which to partition. The default is -1 (the last axis).
431
+ If None, the flattened array is used.
399
432
xp : array_namespace, optional
400
433
The standard-compatible namespace for `a`. Default: infer.
401
434
@@ -407,29 +440,46 @@ def argpartition(
407
440
# Validate inputs.
408
441
if xp is None :
409
442
xp = array_namespace (a )
410
- if a .ndim != 1 :
411
- msg = "only 1-dimensional arrays are currently supported"
412
- raise NotImplementedError (msg )
443
+ if a .ndim < 1 :
444
+ msg = "`a` must be at least 1-dimensional"
445
+ raise TypeError (msg )
446
+ if axis is None :
447
+ return argpartition (xp .reshape (a , - 1 ), kth , axis = 0 , xp = xp )
448
+ size = a .shape [axis ]
449
+ if size is None :
450
+ msg = "Array dimensions must be known"
451
+ raise ValueError (msg )
452
+ if not (0 <= kth < size ):
453
+ msg = f"kth(={ kth } ) out of bounds [0, { size } )"
454
+ raise ValueError (msg )
413
455
414
456
# Delegate where possible.
415
- if is_numpy_namespace (xp ) or is_cupy_namespace (xp ):
416
- return xp .argpartition (a , kth )
417
- if is_jax_namespace (xp ):
418
- from jax import numpy
419
-
420
- return numpy .argpartition (a , kth )
457
+ if is_numpy_namespace (xp ) or is_cupy_namespace (xp ) or is_jax_namespace (xp ):
458
+ return xp .argpartition (a , kth , axis = axis )
421
459
422
460
# Use top-k when possible:
423
461
if is_torch_namespace (xp ):
424
- from torch import topk
462
+ # see `partition` above for commented details of those steps:
463
+ if not (axis == - 1 or axis == a .ndim - 1 ):
464
+ a = xp .transpose (a , axis , - 1 )
465
+
466
+ kth += 1 # HACK
467
+ _ , indices_left = xp .topk (a , kth , dim = - 1 , largest = False , sorted = False )
468
+
469
+ mask_right = xp .ones (a .shape , dtype = bool )
470
+ mask_right .scatter_ (dim = - 1 , index = indices_left , value = False )
471
+
472
+ indices_right = xp .nonzero (mask_right )[- 1 ]
473
+ indices_right = xp .reshape (indices_right , shape = (* a .shape [:- 1 ], - 1 ))
474
+
475
+ # Concatenate the two parts along axis
476
+ index_array = xp .cat ((indices_left , indices_right ), dim = - 1 )
477
+ if not (axis == - 1 or axis == a .ndim - 1 ):
478
+ index_array = xp .transpose (index_array , axis , - 1 )
479
+ return index_array
425
480
426
- _ , indices = topk (a , kth , largest = False , sorted = False )
427
- mask = xp .ones (a .shape , dtype = bool )
428
- mask [indices ] = False
429
- indices_above = xp .arange (a .shape [0 ])[mask ]
430
- return xp .concat ((indices , indices_above ))
431
481
# Note: dask topk/argtopk sort the return values, so it's
432
482
# not much more efficient than sorting everything when
433
483
# kth is not small compared to x.size
434
484
435
- return _funcs .argpartition (a , kth , xp = xp )
485
+ return _funcs .argpartition (a , kth , axis = axis , xp = xp )
0 commit comments