Skip to content

Commit 2333841

Browse files
Maalvi14awni
andauthored
Improved mx.split() docs (#2689)
* Improved mx.split() documentation * Fix typo in docstring for array split function * add example --------- Co-authored-by: Awni Hannun <[email protected]>
1 parent 5bcf3a6 commit 2333841

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

python/src/ops.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2577,13 +2577,23 @@ void init_ops(nb::module_& m) {
25772577
a (array): Input array.
25782578
indices_or_sections (int or list(int)): If ``indices_or_sections``
25792579
is an integer the array is split into that many sections of equal
2580-
size. An error is raised if this is not possible. If ``indices_or_sections``
2581-
is a list, the list contains the indices of the start of each subarray
2582-
along the given axis.
2580+
size. An error is raised if this is not possible. If
2581+
``indices_or_sections`` is a list, then the indices are the split
2582+
points, and the array is divided into
2583+
``len(indices_or_sections) + 1`` sub-arrays.
25832584
axis (int, optional): Axis to split along, defaults to `0`.
25842585
25852586
Returns:
25862587
list(array): A list of split arrays.
2588+
2589+
Example:
2590+
2591+
>>> a = mx.array([1, 2, 3, 4], dtype=mx.int32)
2592+
>>> mx.split(a, 2)
2593+
[array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
2594+
>>> mx.split(a, [1, 3])
2595+
[array([1], dtype=int32), array([2, 3], dtype=int32), array([4], dtype=int32)]
2596+
25872597
)pbdoc");
25882598
m.def(
25892599
"argmin",

0 commit comments

Comments
 (0)