@@ -14132,7 +14132,7 @@ defmodule Nx do
14132
14132
indices = devectorize ( indices , keep_names: false )
14133
14133
out = % { tensor | shape: inner_shape , names: inner_names }
14134
14134
14135
- Nx.Shared . optional ( :take , [ tensor , indices , [ axis: axis ] ] , out , fn tensor , indices , _opts ->
14135
+ Nx.Shared . optional ( :take , [ tensor , indices , axis ] , out , fn tensor , indices , axis ->
14136
14136
gather_indices = new_axis ( indices , rank ( indices ) )
14137
14137
{ indices_axes , tensor_axes } = Enum . split ( axes ( inner_shape ) , rank ( indices ) )
14138
14138
{ leading , trailing } = Enum . split ( tensor_axes , axis )
@@ -14321,12 +14321,20 @@ defmodule Nx do
14321
14321
Builds a new tensor by taking individual values from the original
14322
14322
tensor at the given indices.
14323
14323
14324
+ Indices must be a tensor where the last dimension is usually of the
14325
+ same size as the `tensor` rank. Each entry in `indices` will be
14326
+ part of the results. If the last dimension of indices is less than
14327
+ the `tensor` rank, then a multidimensional tensor is gathered and
14328
+ spliced into the result.
14329
+
14324
14330
## Options
14325
14331
14326
- * `:axes` - controls which dimensions the indexes apply to.
14327
- It must be a sorted list of axes and be of the same size
14328
- as the second (last) dimension of the indexes tensor.
14329
- It defaults to the leading axes of the tensor.
14332
+ * `:axes` - controls to which dimensions of `tensor`
14333
+ each element in the last dimension of `indexes` applies to.
14334
+ It defaults so the first element in indexes apply to the first
14335
+ axis, the second to the second, and so on. It must be a sorted
14336
+ list of axes and be of the same size as the last dimension of
14337
+ the indexes tensor.
14330
14338
14331
14339
## Examples
14332
14340
0 commit comments