Commit 24ed6fa

Remove the Nx.take backend callback (#1439)
1 parent ad45733 · commit 24ed6fa

7 files changed, +10 -129 lines

exla/lib/exla/backend.ex

Lines changed: 0 additions & 1 deletion

@@ -325,7 +325,6 @@ defmodule EXLA.Backend do
     {:reverse, [:tensor, :axes], [:tensor]},
     {:dot, [:left, :c1, :b1, :right, :c2, :b2], [:left, :right]},
     {:clip, [:tensor, :min, :max], [:tensor, :min, :max]},
-    {:take, [:tensor, :indices, :axis], [:tensor, :indices]},
     {:take_along_axis, [:tensor, :indices, :axis], [:tensor, :indices]},
     {:gather, [:input, :indices, :opts], [:input, :indices]},
     {:select, [:pred, :on_true, :on_false], [:pred, :on_true, :on_false]},

exla/lib/exla/defn.ex

Lines changed: 0 additions & 28 deletions

@@ -1220,29 +1220,6 @@ defmodule EXLA.Defn do
     Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans))
   end

-  defp to_operator(:take, [%Value{} = tensor, indices, axis], ans, _state) do
-    tensor_rank = tensor |> op_shape() |> tuple_size()
-    indices_rank = indices |> op_shape() |> tuple_size()
-    result_rank = tensor_rank - 1 + indices_rank
-
-    index_vector_dim = indices_rank
-    slice_sizes = tensor |> op_shape() |> put_elem(axis, 1) |> Tuple.to_list()
-    offset_dims = result_rank |> axes_for_rank() |> delete_slice(axis, indices_rank)
-    collapsed_slice_dims = [axis]
-    start_index_map = [axis]
-
-    Value.gather(
-      tensor,
-      indices,
-      index_vector_dim,
-      slice_sizes,
-      offset_dims,
-      collapsed_slice_dims,
-      start_index_map,
-      expr_to_typespec(ans)
-    )
-  end
-
   defp to_operator(:take_along_axis, [%Value{} = tensor, indices, axis], ans, state) do
     %{shape: indices_shape} = indices_typespec = Value.get_typespec(indices)
     indices_rank = tuple_size(indices_shape)

@@ -1962,11 +1939,6 @@ defmodule EXLA.Defn do

   # Helpers

-  defp delete_slice(enumerable, index, length) do
-    {left, right} = Enum.split(enumerable, index)
-    left ++ Enum.drop(right, length)
-  end
-
   defp apply_mlir_broadcasted_bin_op(op, out, left, right) do
     left_typespec = Value.get_typespec(left)
     right_typespec = Value.get_typespec(right)
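
For context, the removed clause lowered take/3 to a single XLA gather. The parameter mapping can be reproduced in plain Elixir; the sketch below uses a hypothetical tensor shape {4, 5, 6}, indices shape {2, 3}, and axis 1 (none of these values come from the commit) and inlines the logic of the removed delete_slice/3 helper.

# Hypothetical shapes, for illustration only.
tensor_shape = {4, 5, 6}
indices_shape = {2, 3}
axis = 1

tensor_rank = tuple_size(tensor_shape)
indices_rank = tuple_size(indices_shape)
result_rank = tensor_rank - 1 + indices_rank

# Each gathered slice spans the full tensor except along `axis`, which shrinks to 1.
slice_sizes = tensor_shape |> put_elem(axis, 1) |> Tuple.to_list()

# Offset dims are the result axes that do not come from the indices;
# this is what delete_slice(axes, axis, indices_rank) computed.
{left, right} = Enum.split(Enum.to_list(0..(result_rank - 1)), axis)
offset_dims = left ++ Enum.drop(right, indices_rank)

IO.inspect(slice_sizes, label: "slice_sizes")      # [4, 1, 6]
IO.inspect(offset_dims, label: "offset_dims")      # [0, 3]
IO.inspect([axis], label: "collapsed_slice_dims")  # [1]
IO.inspect([axis], label: "start_index_map")       # [1]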

nx/lib/nx.ex

Lines changed: 10 additions & 6 deletions

@@ -14130,13 +14130,17 @@ defmodule Nx do
     else
       tensor = devectorize(tensor, keep_names: false)
       indices = devectorize(indices, keep_names: false)
+      gather_indices = new_axis(indices, rank(indices))

-      impl!(tensor).take(
-        %{tensor | shape: inner_shape, names: inner_names},
-        tensor,
-        indices,
-        axis
-      )
+      {indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices))
+      {leading, trailing} = Enum.split(tensor_axes, axis)
+
+      transpose_axes = leading ++ indices_axes ++ trailing
+
+      tensor
+      |> gather(gather_indices, axes: [axis])
+      |> transpose(axes: transpose_axes)
+      |> reshape(inner_shape, names: inner_names)
     end
   end

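With the backend callback gone, Nx.take/3 is now expressed entirely through gather, transpose, and reshape. A minimal sketch of that equivalence against the public API, using a made-up {2, 3} tensor, rank-1 indices, and axis 1 (the values are illustrative, not from the commit); it assumes Nx.gather/3 with the :axes option, as used in the new implementation above.

t = Nx.tensor([[1, 2, 3], [4, 5, 6]])
idx = Nx.tensor([2, 0])
axis = 1

# Public API call.
taken = Nx.take(t, idx, axis: axis)
# taken == [[3, 1], [6, 4]]

# The same result, spelled out the way Nx.take now builds it:
# gather along `axis`, move the index axes into place, then reshape.
gather_indices = Nx.new_axis(idx, Nx.rank(idx))

# Output shape for rank-1 indices: the `axis` dimension is replaced by the index count.
out_shape = t |> Nx.shape() |> put_elem(axis, Nx.size(idx))

{indices_axes, tensor_axes} = Enum.split(Nx.axes(out_shape), Nx.rank(idx))
{leading, trailing} = Enum.split(tensor_axes, axis)

manual =
  t
  |> Nx.gather(gather_indices, axes: [axis])
  |> Nx.transpose(axes: leading ++ indices_axes ++ trailing)
  |> Nx.reshape(out_shape)

Nx.equal(taken, manual) |> Nx.all() |> Nx.to_number()
# => 1
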
nx/lib/nx/backend.ex

Lines changed: 0 additions & 1 deletion

@@ -73,7 +73,6 @@ defmodule Nx.Backend do
   @callback clip(out :: tensor, tensor, min :: tensor, max :: tensor) :: tensor
   @callback slice(out :: tensor, tensor, list, list, list) :: tensor
   @callback put_slice(out :: tensor, tensor, tensor, list) :: tensor
-  @callback take(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor
   @callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor
   @callback gather(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
   @callback concatenate(out :: tensor, tensor, axis) :: tensor

nx/lib/nx/binary_backend.ex

Lines changed: 0 additions & 41 deletions

@@ -1939,47 +1939,6 @@ defmodule Nx.BinaryBackend do
     from_binary(out, data)
   end

-  @impl true
-  def take(out, tensor, indices, axis) do
-    # We iterate over the indices in a flat manner,
-    # and take a unit tensor slice along axis given
-    # by each index. Then we concatenate the tensors
-    # along the axis, which gives us the result with
-    # index dimensions flattened and we just reshape.
-
-    %T{type: {_, size}, shape: shape} = tensor
-    %T{type: {_, idx_size}} = indices
-
-    data = to_binary(tensor)
-    tensor_rank = tuple_size(shape)
-    slice_start = List.duplicate(0, tensor_rank)
-    slice_lengths = shape |> Tuple.to_list() |> List.replace_at(axis, 1)
-    slice_shape = List.to_tuple(slice_lengths)
-    strides = List.duplicate(1, tensor_rank)
-
-    slices =
-      for <<bin::size(idx_size)-bitstring <- to_binary(indices)>> do
-        idx = binary_to_number(bin, indices.type)
-
-        if idx < 0 or idx >= elem(shape, axis) do
-          raise ArgumentError,
-                "index #{idx} is out of bounds for axis #{axis} in shape #{inspect(shape)}"
-        end
-
-        slice_start = List.replace_at(slice_start, axis, idx)
-
-        slice_data =
-          bin_slice(data, shape, size, slice_start, slice_lengths, strides, slice_shape)
-
-        {slice_data, slice_shape}
-      end
-
-    concat_shape = put_elem(tensor.shape, axis, length(slices))
-    result_data = bin_concatenate(slices, size, axis, concat_shape)
-
-    from_binary(out, result_data)
-  end
-
   @impl true
   def take_along_axis(
         %T{type: output_type} = output,
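
The comment in the removed clause describes its algorithm: take one unit-width slice along the axis per index, concatenate the slices, then reshape. A small sketch of that idea against the public Nx API, with a made-up tensor and rank-1 indices (so no final reshape is needed); the values are illustrative only, and this is not the binary-level implementation itself.

t = Nx.tensor([[1, 2, 3], [4, 5, 6]])
indices = [2, 0]
axis = 1

# One unit-width slice along `axis` per index, then concatenate along that axis.
result =
  indices
  |> Enum.map(fn i -> Nx.slice_along_axis(t, i, 1, axis: axis) end)
  |> Nx.concatenate(axis: axis)

# result == [[3, 1], [6, 4]], matching Nx.take(t, Nx.tensor(indices), axis: 1).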

nx/lib/nx/defn/expr.ex

Lines changed: 0 additions & 6 deletions

@@ -1183,12 +1183,6 @@ defmodule Nx.Defn.Expr do
     expr(out, context, :put_slice, [tensor, start, slice])
   end

-  @impl true
-  def take(out, tensor, indices, axis) do
-    {[tensor, indices], context} = to_exprs([tensor, indices])
-    expr(out, context, :take, [tensor, indices, axis])
-  end
-
   @impl true
   def take_along_axis(out, tensor, indices, axis) do
     {[tensor, indices], context} = to_exprs([tensor, indices])
torchx/lib/torchx/backend.ex

Lines changed: 0 additions & 46 deletions

@@ -337,52 +337,6 @@ defmodule Torchx.Backend do
     |> to_nx(out)
   end

-  @impl true
-  def take(out, t, i, axis) do
-    axes = Nx.axes(t)
-
-    indices_shape =
-      axes
-      |> Enum.map(fn
-        ^axis -> Tuple.product(i.shape)
-        _ -> 1
-      end)
-      |> List.to_tuple()
-
-    idx_tiling =
-      t.shape
-      |> Tuple.to_list()
-      |> Enum.with_index(fn
-        _x, ^axis -> 1
-        x, _ -> x
-      end)
-
-    indices_for_axis =
-      i
-      |> Nx.reshape(indices_shape)
-      |> Nx.tile(idx_tiling)
-
-    num_elements = Tuple.product(indices_for_axis.shape)
-
-    indices =
-      axes
-      |> Enum.map(fn
-        ^axis ->
-          Nx.reshape(indices_for_axis, {num_elements, 1})
-
-        current ->
-          # current when current < axis ->
-          indices_for_axis
-          |> Nx.shape()
-          |> Nx.iota(axis: current, backend: __MODULE__)
-          |> Nx.reshape({num_elements, 1})
-      end)
-      |> Nx.concatenate(axis: 1)
-
-    # TODO: maybe rewrite it as gather now behaves differently
-    gather(out, t, indices, [])
-  end
-
   @impl true
   def gather(out, tensor, indices, opts) do
     tensor_axes = Nx.axes(tensor)
