Skip to content

Commit 50ecf0a

Browse files
committed
Remove other uses of :take
1 parent 24ed6fa commit 50ecf0a

File tree

1 file changed

+0
-66
lines changed

1 file changed

+0
-66
lines changed

nx/lib/nx/defn/grad.ex

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,6 @@ defmodule Nx.Defn.Grad do
153153
defp reduce_args(:take_along_axis, %{data: %{args: [arg | _]}}, acc, fun),
154154
do: fun.(arg, acc)
155155

156-
defp reduce_args(:take, %{data: %{args: [arg | _]}}, acc, fun),
157-
do: fun.(arg, acc)
158-
159156
defp reduce_args(:gather, %{data: %{args: [arg | _]}}, acc, fun),
160157
do: fun.(arg, acc)
161158

@@ -704,69 +701,6 @@ defmodule Nx.Defn.Grad do
704701
[{t, g}]
705702
end
706703

707-
defp grad(:take, [t, i, axis], _ans, g) do
708-
axes_range = 0..(Nx.rank(t) - 1)//1
709-
710-
indices_shape =
711-
axes_range
712-
|> Enum.flat_map(fn
713-
^axis -> Tuple.to_list(i.shape)
714-
_ -> [1]
715-
end)
716-
|> List.to_tuple()
717-
718-
idx_tiling =
719-
t.shape
720-
|> Tuple.to_list()
721-
|> Enum.with_index(fn
722-
_x, ^axis ->
723-
List.duplicate(1, Nx.rank(i))
724-
725-
x, _ ->
726-
x
727-
end)
728-
|> List.flatten()
729-
730-
num_elements = Tuple.product(g.shape)
731-
732-
indices_for_axis =
733-
i
734-
|> Nx.reshape(indices_shape)
735-
|> Nx.tile(idx_tiling)
736-
737-
axis_offset = Nx.rank(i) - 1
738-
739-
indices =
740-
axes_range
741-
|> Enum.map(fn
742-
^axis ->
743-
indices_for_axis
744-
|> Nx.reshape({num_elements, 1})
745-
746-
current when current < axis ->
747-
indices_for_axis
748-
|> Nx.shape()
749-
|> Nx.iota(axis: current)
750-
|> Nx.reshape({num_elements, 1})
751-
752-
current when current > axis ->
753-
indices_for_axis
754-
|> Nx.shape()
755-
|> Nx.iota(axis: current + axis_offset)
756-
|> Nx.reshape({num_elements, 1})
757-
end)
758-
|> Nx.concatenate(axis: 1)
759-
760-
updates = Nx.reshape(g, {num_elements})
761-
762-
g =
763-
t
764-
|> Expr.broadcast(0, Nx.shape(t), Nx.axes(t))
765-
|> Nx.indexed_add(indices, updates)
766-
767-
[{t, g}]
768-
end
769-
770704
defp grad(:gather, [t, i, opts], _ans, g) do
771705
i_axes = opts[:axes]
772706
i_shape = i.shape

0 commit comments

Comments
 (0)