Skip to content

Commit 7c36e06

Browse files
Benjamin-Philip and josevalim
authored and committed
Make take_along_axis an optional callback
Closes #1440.
1 parent 6139d2a commit 7c36e06

File tree

8 files changed

+26
-131
lines changed

8 files changed

+26
-131
lines changed

exla/lib/exla/backend.ex

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,6 @@ defmodule EXLA.Backend do
325325
{:reverse, [:tensor, :axes], [:tensor]},
326326
{:dot, [:left, :c1, :b1, :right, :c2, :b2], [:left, :right]},
327327
{:clip, [:tensor, :min, :max], [:tensor, :min, :max]},
328-
{:take_along_axis, [:tensor, :indices, :axis], [:tensor, :indices]},
329328
{:gather, [:input, :indices, :opts], [:input, :indices]},
330329
{:select, [:pred, :on_true, :on_false], [:pred, :on_true, :on_false]},
331330
{:conv, [:tensor, :kernel, :opts], [:tensor, :kernel]},

exla/lib/exla/defn.ex

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,43 +1259,6 @@ defmodule EXLA.Defn do
12591259
Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans))
12601260
end
12611261

1262-
defp to_operator(:take_along_axis, [%Value{} = tensor, indices, axis], ans, state) do
1263-
%{shape: indices_shape} = indices_typespec = Value.get_typespec(indices)
1264-
indices_rank = tuple_size(indices_shape)
1265-
1266-
axes_range = 0..(indices_rank - 1)//1
1267-
1268-
index_vector_dim = indices_rank
1269-
slice_sizes = List.duplicate(1, indices_rank)
1270-
offset_dims = []
1271-
collapsed_slice_dims = Enum.to_list(axes_range)
1272-
start_index_map = Enum.to_list(axes_range)
1273-
1274-
new_axis_typespec = Typespec.to_shape(indices_typespec, Tuple.append(indices_shape, 1))
1275-
1276-
full_indices_typespec =
1277-
Typespec.to_shape(indices_typespec, Tuple.append(indices_shape, indices_rank))
1278-
1279-
full_indices =
1280-
axes_range
1281-
|> Enum.map(fn
1282-
^axis -> Value.reshape(indices, new_axis_typespec)
1283-
axis -> Value.iota(state.builder, axis, new_axis_typespec)
1284-
end)
1285-
|> Value.concatenate(indices_rank, full_indices_typespec)
1286-
1287-
Value.gather(
1288-
tensor,
1289-
full_indices,
1290-
index_vector_dim,
1291-
slice_sizes,
1292-
offset_dims,
1293-
collapsed_slice_dims,
1294-
start_index_map,
1295-
expr_to_typespec(ans)
1296-
)
1297-
end
1298-
12991262
defp to_operator(:gather, [%Value{} = tensor, indices, opts], ans, _state) do
13001263
axes = Keyword.fetch!(opts, :axes)
13011264
tensor_shape = op_shape(tensor)

nx/lib/nx.ex

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14142,7 +14142,7 @@ defmodule Nx do
1414214142
tensor
1414314143
|> gather(gather_indices, axes: [axis])
1414414144
|> transpose(axes: transpose_axes)
14145-
|> reshape(inner_shape, names: inner_names)
14145+
|> rename(inner_names)
1414614146
end)
1414714147
end
1414814148
end
@@ -14302,17 +14302,32 @@ defmodule Nx do
1430214302
end
1430314303

1430414304
opts = keyword!(opts, axis: 0)
14305-
1430614305
tensor = devectorize(tensor, keep_names: false)
1430714306
indices = devectorize(indices, keep_names: false)
14308-
1430914307
offset = length(vectorized_axes)
1431014308

1431114309
axis = Nx.Shape.normalize_axis(tensor.shape, opts[:axis], tensor.names, offset)
14312-
1431314310
shape = Nx.Shape.take_along_axis(tensor.shape, indices.shape, axis)
14311+
out = %{tensor | shape: shape}
1431414312

14315-
result = impl!(tensor).take_along_axis(%{tensor | shape: shape}, tensor, indices, axis)
14313+
result =
14314+
Nx.Shared.optional(:take_along_axis, [tensor, indices, [axis: axis]], out, fn
14315+
tensor, indices, _opts ->
14316+
axes_range = axes(indices)
14317+
new_axis_shape = Tuple.append(shape(indices), 1)
14318+
14319+
full_indices =
14320+
axes_range
14321+
|> Enum.map(fn
14322+
^axis -> reshape(indices, new_axis_shape)
14323+
axis -> iota(new_axis_shape, axis: axis)
14324+
end)
14325+
|> concatenate(axis: rank(indices))
14326+
14327+
tensor
14328+
|> gather(full_indices)
14329+
|> rename(tensor.names)
14330+
end)
1431614331

1431714332
vectorize(result, vectorized_axes)
1431814333
end

nx/lib/nx/backend.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ defmodule Nx.Backend do
7373
@callback clip(out :: tensor, tensor, min :: tensor, max :: tensor) :: tensor
7474
@callback slice(out :: tensor, tensor, list, list, list) :: tensor
7575
@callback put_slice(out :: tensor, tensor, tensor, list) :: tensor
76-
@callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor
7776
@callback gather(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
7877
@callback concatenate(out :: tensor, tensor, axis) :: tensor
7978
@callback select(out :: tensor, tensor, tensor, tensor) :: tensor
@@ -159,6 +158,7 @@ defmodule Nx.Backend do
159158
@callback all_close(out :: tensor, tensor, tensor, keyword) :: tensor
160159
@callback top_k(out :: tensor, tensor, keyword) :: tensor
161160
@callback take(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
161+
@callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
162162

163163
@optional_callbacks [
164164
optional: 3,
@@ -178,7 +178,8 @@ defmodule Nx.Backend do
178178
qr: 3,
179179
cholesky: 2,
180180
eigh: 3,
181-
take: 4
181+
take: 4,
182+
take_along_axis: 4
182183
]
183184

184185
## Inspect implementation

nx/lib/nx/binary_backend.ex

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,43 +1939,6 @@ defmodule Nx.BinaryBackend do
19391939
from_binary(out, data)
19401940
end
19411941

1942-
@impl true
1943-
def take_along_axis(
1944-
%T{type: output_type} = output,
1945-
%T{shape: t_shape, type: {_, t_size} = t_type} = tensor,
1946-
%T{shape: idx_shape, type: {_, idx_size} = idx_type} = indices,
1947-
axis
1948-
) do
1949-
permutation =
1950-
tensor
1951-
|> Nx.axes()
1952-
|> List.delete(axis)
1953-
|> List.insert_at(Nx.rank(tensor) - 1, axis)
1954-
1955-
inverse_permutation = inverse_permutation(permutation)
1956-
shape_list = Tuple.to_list(output.shape)
1957-
permuted_shape = permutation |> Enum.map(&Enum.at(shape_list, &1)) |> List.to_tuple()
1958-
1959-
t_view = tensor |> to_binary() |> aggregate_axes([axis], t_shape, t_size)
1960-
idx_view = indices |> to_binary() |> aggregate_axes([axis], idx_shape, idx_size)
1961-
1962-
[t_view, idx_view]
1963-
|> Enum.zip_with(fn [data_bin, idx_bin] ->
1964-
data = binary_to_list(data_bin, t_type)
1965-
1966-
binary_to_binary(idx_bin, idx_type, output_type, fn idx ->
1967-
if idx < 0 or idx >= elem(tensor.shape, axis) do
1968-
raise ArgumentError,
1969-
"index #{idx} is out of bounds for axis #{axis} in shape #{inspect(tensor.shape)}"
1970-
end
1971-
1972-
Enum.at(data, idx)
1973-
end)
1974-
end)
1975-
|> then(&from_binary(%{output | shape: permuted_shape}, &1))
1976-
|> then(&transpose(output, &1, inverse_permutation))
1977-
end
1978-
19791942
@impl true
19801943
def gather(out, tensor, indices, opts) do
19811944
axes = opts[:axes]

nx/lib/nx/defn/expr.ex

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,12 +1183,6 @@ defmodule Nx.Defn.Expr do
11831183
expr(out, context, :put_slice, [tensor, start, slice])
11841184
end
11851185

1186-
@impl true
1187-
def take_along_axis(out, tensor, indices, axis) do
1188-
{[tensor, indices], context} = to_exprs([tensor, indices])
1189-
expr(out, context, :take_along_axis, [tensor, indices, axis])
1190-
end
1191-
11921186
@impl true
11931187
def gather(out, tensor, indices, opts) do
11941188
{[tensor, indices], context} = to_exprs([tensor, indices])

nx/lib/nx/defn/grad.ex

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,6 @@ defmodule Nx.Defn.Grad do
150150
defp reduce_args(:put_slice, %{data: %{args: [arg, _, update | _]}}, acc, fun),
151151
do: fun.(arg, fun.(update, acc))
152152

153-
defp reduce_args(:take_along_axis, %{data: %{args: [arg | _]}}, acc, fun),
154-
do: fun.(arg, acc)
155-
156153
defp reduce_args(:gather, %{data: %{args: [arg | _]}}, acc, fun),
157154
do: fun.(arg, acc)
158155

@@ -663,44 +660,6 @@ defmodule Nx.Defn.Grad do
663660
[{t, g}]
664661
end
665662

666-
defp grad(:take_along_axis, [t, i, axis], _ans, g) do
667-
num_elements = i |> Nx.shape() |> Tuple.product()
668-
669-
# Convert `i`, the take_along_axis indices, to a list of
670-
# fully qualified (i.e. [0, 2, 1] for a {_, _, _}-shaped tensor)
671-
# indices
672-
673-
indices =
674-
0..(Nx.rank(g) - 1)//1
675-
|> Enum.map(fn
676-
# For the axis of interest, we'll use the actual take_along_axis indices
677-
^axis ->
678-
Nx.reshape(i, {num_elements, 1})
679-
680-
axis ->
681-
i
682-
|> Nx.shape()
683-
|> Nx.iota(axis: axis)
684-
|> Nx.reshape({num_elements, 1})
685-
end)
686-
|> Nx.concatenate(axis: 1)
687-
688-
# Since g is produced through the given indices,
689-
# we can reshape g to be a {num_elements} shaped tensor
690-
# which will directly correspond to each of the reshaped
691-
# indices above
692-
updates = Nx.reshape(g, {num_elements})
693-
694-
# The intuition for this grad is that for each index taken, we'll
695-
# add the corresponding result grad to the original
696-
g =
697-
t
698-
|> Expr.broadcast(0, Nx.shape(t), Nx.axes(t))
699-
|> Nx.indexed_add(indices, updates)
700-
701-
[{t, g}]
702-
end
703-
704663
defp grad(:gather, [t, i, opts], _ans, g) do
705664
i_axes = opts[:axes]
706665
i_shape = i.shape
@@ -714,6 +673,7 @@ defmodule Nx.Defn.Grad do
714673

715674
g =
716675
0
676+
|> Nx.as_type(t.type)
717677
|> Nx.broadcast(t_shape)
718678
|> Nx.indexed_add(indices, updates, opts)
719679

torchx/lib/torchx/backend.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,12 +425,12 @@ defmodule Torchx.Backend do
425425
end
426426

427427
@impl true
428-
def take_along_axis(out, tensor, idx, axis) do
428+
def take_along_axis(out, tensor, idx, opts) do
429429
idx_tx = idx |> from_nx() |> Torchx.to_type(:long)
430430

431431
tensor
432432
|> from_nx()
433-
|> Torchx.gather(idx_tx, axis)
433+
|> Torchx.gather(idx_tx, opts[:axis])
434434
|> to_nx(out)
435435
end
436436

0 commit comments

Comments
 (0)