@@ -339,22 +339,33 @@ def partition(
339
339
"""
340
340
Return a partitioned copy of an array.
341
341
342
- Parameters
342
+ Creates a copy of the array and partially sorts it in such a way that the value
343
+ of the element in k-th position is in the position it would be in a sorted array.
344
+ In the output array, all elements smaller than the k-th element are located to
345
+ the left of this element and all equal or greater are located to its right.
346
+ The ordering of the elements in the two partitions on the either side of
347
+ the k-th element in the output array is undefined.
348
+
343
349
----------
344
- a : 1-dimensional array
350
+ a : Array
345
351
Input array.
346
352
kth : int
347
353
Element index to partition by.
348
354
axis : int, optional
349
- Axis along which to partition. The default is -1 (the last axis).
350
- If None, the flattened array is used.
355
+ Axis along which to partition. The default is ``-1`` (the last axis).
356
+ If `` None`` , the flattened array is used.
351
357
xp : array_namespace, optional
352
358
The standard-compatible namespace for `x`. Default: infer.
353
359
354
360
Returns
355
361
-------
356
362
partitioned_array
357
- Array of the same type and shape as a.
363
+ Array of the same type and shape as `a`.
364
+
365
+ Notes:
366
+ If `xp` implements `partition` or an equivalent method (e.g. topk for torch),
367
+ complexity will likely be O(n).
368
+ If not, this function simply calls `xp.sort` and complexity is O(n log n).
358
369
"""
359
370
# Validate inputs.
360
371
if xp is None :
@@ -416,6 +427,8 @@ def argpartition(
416
427
) -> Array :
417
428
"""
418
429
Perform an indirect partition along the given axis.
430
+ It returns an array of indices of the same shape as `a` that
431
+ index data along the given axis in partitioned order.
419
432
420
433
Parameters
421
434
----------
@@ -424,15 +437,20 @@ def argpartition(
424
437
kth : int
425
438
Element index to partition by.
426
439
axis : int, optional
427
- Axis along which to partition. The default is -1 (the last axis).
428
- If None, the flattened array is used.
440
+ Axis along which to partition. The default is ``-1`` (the last axis).
441
+ If `` None`` , the flattened array is used.
429
442
xp : array_namespace, optional
430
443
The standard-compatible namespace for `x`. Default: infer.
431
444
432
445
Returns
433
446
-------
434
447
index_array
435
448
Array of indices that partition `a` along the specified axis.
449
+
450
+ Notes:
451
+ If `xp` implements `argpartition` or an equivalent method (e.g. topk for torch),
452
+ complexity will likely be O(n).
453
+ If not, this function simply calls `xp.argsort` and complexity is O(n log n).
436
454
"""
437
455
# Validate inputs.
438
456
if xp is None :
0 commit comments