Skip to content

Commit c88e93e

Browse files
committed
Docstring improvements
1 parent 259f93d commit c88e93e

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

src/array_api_extra/_delegation.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -339,22 +339,33 @@ def partition(
339339
"""
340340
Return a partitioned copy of an array.
341341
342-
Parameters
342+
Creates a copy of the array and partially sorts it in such a way that the value
343+
of the element in k-th position is in the position it would be in a sorted array.
344+
In the output array, all elements smaller than the k-th element are located to
345+
the left of this element and all equal or greater are located to its right.
346+
The ordering of the elements in the two partitions on the either side of
347+
the k-th element in the output array is undefined.
348+
343349
----------
344-
a : 1-dimensional array
350+
a : Array
345351
Input array.
346352
kth : int
347353
Element index to partition by.
348354
axis : int, optional
349-
Axis along which to partition. The default is -1 (the last axis).
350-
If None, the flattened array is used.
355+
Axis along which to partition. The default is ``-1`` (the last axis).
356+
If ``None``, the flattened array is used.
351357
xp : array_namespace, optional
352358
The standard-compatible namespace for `x`. Default: infer.
353359
354360
Returns
355361
-------
356362
partitioned_array
357-
Array of the same type and shape as a.
363+
Array of the same type and shape as `a`.
364+
365+
Notes:
366+
If `xp` implements `partition` or an equivalent method (e.g. topk for torch),
367+
complexity will likely be O(n).
368+
If not, this function simply calls `xp.sort` and complexity is O(n log n).
358369
"""
359370
# Validate inputs.
360371
if xp is None:
@@ -416,6 +427,8 @@ def argpartition(
416427
) -> Array:
417428
"""
418429
Perform an indirect partition along the given axis.
430+
It returns an array of indices of the same shape as `a` that
431+
index data along the given axis in partitioned order.
419432
420433
Parameters
421434
----------
@@ -424,15 +437,20 @@ def argpartition(
424437
kth : int
425438
Element index to partition by.
426439
axis : int, optional
427-
Axis along which to partition. The default is -1 (the last axis).
428-
If None, the flattened array is used.
440+
Axis along which to partition. The default is ``-1`` (the last axis).
441+
If ``None``, the flattened array is used.
429442
xp : array_namespace, optional
430443
The standard-compatible namespace for `x`. Default: infer.
431444
432445
Returns
433446
-------
434447
index_array
435448
Array of indices that partition `a` along the specified axis.
449+
450+
Notes:
451+
If `xp` implements `argpartition` or an equivalent method (e.g. topk for torch),
452+
complexity will likely be O(n).
453+
If not, this function simply calls `xp.argsort` and complexity is O(n log n).
436454
"""
437455
# Validate inputs.
438456
if xp is None:

0 commit comments

Comments
 (0)