Skip to content

Commit ad45733

Browse files
committed
Fix bugs with gather
1 parent 73c5417 commit ad45733

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

exla/lib/exla/defn.ex

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,15 +1285,17 @@ defmodule EXLA.Defn do
12851285
tensor_shape = op_shape(tensor)
12861286
tensor_rank = tuple_size(tensor_shape)
12871287
tensor_axes = axes_for_rank(tensor_rank)
1288-
index_vector_dim = tuple_size(op_shape(indices)) - 1
1288+
indices_rank = tuple_size(op_shape(indices))
1289+
index_vector_dim = indices_rank - 1
12891290

12901291
slice_sizes =
12911292
for i <- tensor_axes do
12921293
if i in axes, do: 1, else: elem(tensor_shape, i)
12931294
end
12941295

12951296
batch_size = tensor_rank - length(axes)
1296-
offset_dims = count_up(batch_size, batch_size)
1297+
offset_size = indices_rank - length(axes)
1298+
offset_dims = count_up(batch_size, offset_size)
12971299

12981300
Value.gather(
12991301
tensor,

nx/lib/nx.ex

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7876,6 +7876,8 @@ defmodule Nx do
78767876
end
78777877

78787878
defp indexed_axes(tensor, indices, opts) do
7879+
n = elem(indices.shape, tuple_size(indices.shape) - 1)
7880+
78797881
if axes = opts[:axes] do
78807882
axes = Nx.Shape.normalize_axes(tensor.shape, axes, tensor.names)
78817883

@@ -7884,9 +7886,13 @@ defmodule Nx do
78847886
_, _ -> raise ArgumentError, ":axes must be an ordered list"
78857887
end)
78867888

7889+
if length(axes) != n do
7890+
raise ArgumentError,
7891+
":axes must have the same number of elements as the last dimension of indices"
7892+
end
7893+
78877894
axes
78887895
else
7889-
n = elem(indices.shape, tuple_size(indices.shape) - 1)
78907896
Enum.to_list(0..(n - 1))
78917897
end
78927898
end

0 commit comments

Comments
 (0)