File tree Expand file tree Collapse file tree 2 files changed +11
-3
lines changed Expand file tree Collapse file tree 2 files changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -1285,15 +1285,17 @@ defmodule EXLA.Defn do
1285
1285
tensor_shape = op_shape ( tensor )
1286
1286
tensor_rank = tuple_size ( tensor_shape )
1287
1287
tensor_axes = axes_for_rank ( tensor_rank )
1288
- index_vector_dim = tuple_size ( op_shape ( indices ) ) - 1
1288
+ indices_rank = tuple_size ( op_shape ( indices ) )
1289
+ index_vector_dim = indices_rank - 1
1289
1290
1290
1291
slice_sizes =
1291
1292
for i <- tensor_axes do
1292
1293
if i in axes , do: 1 , else: elem ( tensor_shape , i )
1293
1294
end
1294
1295
1295
1296
batch_size = tensor_rank - length ( axes )
1296
- offset_dims = count_up ( batch_size , batch_size )
1297
+ offset_size = indices_rank - length ( axes )
1298
+ offset_dims = count_up ( batch_size , offset_size )
1297
1299
1298
1300
Value . gather (
1299
1301
tensor ,
Original file line number Diff line number Diff line change @@ -7876,6 +7876,8 @@ defmodule Nx do
7876
7876
end
7877
7877
7878
7878
defp indexed_axes ( tensor , indices , opts ) do
7879
+ n = elem ( indices . shape , tuple_size ( indices . shape ) - 1 )
7880
+
7879
7881
if axes = opts [ :axes ] do
7880
7882
axes = Nx.Shape . normalize_axes ( tensor . shape , axes , tensor . names )
7881
7883
@@ -7884,9 +7886,13 @@ defmodule Nx do
7884
7886
_ , _ -> raise ArgumentError , ":axes must be an ordered list"
7885
7887
end )
7886
7888
7889
+ if length ( axes ) != n do
7890
+ raise ArgumentError ,
7891
+ ":axes must have the same number of elements as the last dimension of indices"
7892
+ end
7893
+
7887
7894
axes
7888
7895
else
7889
- n = elem ( indices . shape , tuple_size ( indices . shape ) - 1 )
7890
7896
Enum . to_list ( 0 .. ( n - 1 ) )
7891
7897
end
7892
7898
end
You can’t perform that action at this time.
0 commit comments