Skip to content

Commit b8f2254

Browse files
committed
Make take an optional callback
1 parent 50ecf0a commit b8f2254

File tree

2 files changed

+58
-16
lines changed

2 files changed

+58
-16
lines changed

exla/lib/exla/defn.ex

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -598,8 +598,49 @@ defmodule EXLA.Defn do
598598

599599
defp cached_recur_operator(
600600
:optional,
601-
%T{data: %Expr{args: [%{data: %{op: :top_k, args: [tensor, opts]}}, expr, _callback]}} =
602-
_out,
601+
%T{
602+
data: %Expr{
603+
args: [%{data: %{op: :take, args: [tensor, indices, opts]}}, expr, _callback]
604+
}
605+
},
606+
state,
607+
cache
608+
) do
609+
axis = opts[:axis]
610+
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
611+
{indices, cache} = recur_operator(indices, state, cache) |> unwrap_single_tensor!()
612+
613+
tensor_rank = tensor |> op_shape() |> tuple_size()
614+
indices_rank = indices |> op_shape() |> tuple_size()
615+
result_rank = tensor_rank - 1 + indices_rank
616+
617+
index_vector_dim = indices_rank
618+
slice_sizes = tensor |> op_shape() |> put_elem(axis, 1) |> Tuple.to_list()
619+
620+
{left, right} = result_rank |> axes_for_rank() |> Enum.split(axis)
621+
offset_dims = left ++ Enum.drop(right, indices_rank)
622+
623+
collapsed_slice_dims = [axis]
624+
start_index_map = [axis]
625+
626+
result =
627+
Value.gather(
628+
tensor,
629+
indices,
630+
index_vector_dim,
631+
slice_sizes,
632+
offset_dims,
633+
collapsed_slice_dims,
634+
start_index_map,
635+
expr_to_typespec(expr)
636+
)
637+
638+
{result, cache}
639+
end
640+
641+
defp cached_recur_operator(
642+
:optional,
643+
%T{data: %Expr{args: [%{data: %{op: :top_k, args: [tensor, opts]}}, expr, _callback]}},
603644
state,
604645
cache
605646
) do
@@ -612,26 +653,24 @@ defmodule EXLA.Defn do
612653

613654
defp cached_recur_operator(
614655
:optional,
615-
%T{data: %Expr{args: [%{data: %{op: :fft2, args: [tensor, opts]}}, _expr, _callback]}} =
616-
out,
656+
%T{data: %Expr{args: [%{data: %{op: :fft2, args: [tensor, opts]}}, expr, _callback]}},
617657
state,
618658
cache
619659
) do
620660
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
621661

622-
{fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], out, state), cache}
662+
{fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], expr, state), cache}
623663
end
624664

625665
defp cached_recur_operator(
626666
:optional,
627-
%T{data: %Expr{args: [%{data: %{op: :ifft2, args: [tensor, opts]}}, _expr, _callback]}} =
628-
out,
667+
%T{data: %Expr{args: [%{data: %{op: :ifft2, args: [tensor, opts]}}, expr, _callback]}},
629668
state,
630669
cache
631670
) do
632671
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
633672

634-
{fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], out, state), cache}
673+
{fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr, state), cache}
635674
end
636675

637676
defp cached_recur_operator(:optional, %T{data: %Expr{args: args}}, state, cache) do

nx/lib/nx.ex

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14130,17 +14130,20 @@ defmodule Nx do
1413014130
else
1413114131
tensor = devectorize(tensor, keep_names: false)
1413214132
indices = devectorize(indices, keep_names: false)
14133-
gather_indices = new_axis(indices, rank(indices))
14133+
out = %{tensor | shape: inner_shape, names: inner_names}
1413414134

14135-
{indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices))
14136-
{leading, trailing} = Enum.split(tensor_axes, axis)
14135+
Nx.Shared.optional(:take, [tensor, indices, [axis: axis]], out, fn tensor, indices, _opts ->
14136+
gather_indices = new_axis(indices, rank(indices))
14137+
{indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices))
14138+
{leading, trailing} = Enum.split(tensor_axes, axis)
1413714139

14138-
transpose_axes = leading ++ indices_axes ++ trailing
14140+
transpose_axes = leading ++ indices_axes ++ trailing
1413914141

14140-
tensor
14141-
|> gather(gather_indices, axes: [axis])
14142-
|> transpose(axes: transpose_axes)
14143-
|> reshape(inner_shape, names: inner_names)
14142+
tensor
14143+
|> gather(gather_indices, axes: [axis])
14144+
|> transpose(axes: transpose_axes)
14145+
|> reshape(inner_shape, names: inner_names)
14146+
end)
1414414147
end
1414514148
end
1414614149

0 commit comments

Comments
 (0)