Skip to content

feat: Nx.elixir_call/3 #1627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,10 @@ defmodule EXLA.Defn do
EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type)
end

# Nx.elixir_call/3 has no XLA lowering yet, so hitting it under the EXLA
# compiler is an explicit error rather than a silent fallback.
# NOTE: this exact message is asserted verbatim in
# exla/test/exla/defn/elixir_call_test.exs — keep the two in sync.
defp to_operator(:elixir_call, _, _, _) do
raise "Nx.elixir_call/3 is not supported yet. Use Nx.Defn.Evaluator as your compiler."
end

defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do
n = opts[:length]
axis = opts[:axis]
Expand Down
61 changes: 61 additions & 0 deletions exla/test/exla/defn/elixir_call_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
defmodule EXLA.Defn.ElixirCallEvaluatorTest do
  use ExUnit.Case, async: true
  import Nx.Defn
  import Nx.Testing

  # Evaluator compiler + EXLA backend: elixir_call runs the Elixir fun eagerly.
  setup do
    Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)
    Nx.default_backend(EXLA.Backend)
    :ok
  end

  # Single-tensor output: the template promotes the input type to floating point.
  defn add_offset(input) do
    template = %{input | type: Nx.Type.to_floating(input.type)}

    Nx.elixir_call(template, [input, [offset: 10.0]], fn tensor, opts ->
      Nx.add(Nx.as_type(tensor, :f32), opts[:offset])
    end)
  end

  test "elixir_call with single output" do
    input = Nx.iota({5})
    result = add_offset(input)

    assert_equal(result, Nx.add(Nx.as_type(input, :f32), 10.0))
  end

  # Tuple output: the template is a two-element tuple of tensor templates.
  defn split_and_sum(input) do
    floated = Nx.as_type(input, :f32)

    {doubled, incremented} =
      Nx.elixir_call({floated, floated}, [floated], fn tensor ->
        {Nx.multiply(tensor, 2.0), Nx.add(tensor, 1.0)}
      end)

    Nx.add(doubled, incremented)
  end

  test "elixir_call with tuple output" do
    input = Nx.tensor([1, 2, 3])
    floated = Nx.as_type(input, :f32)

    assert_equal(
      split_and_sum(input),
      Nx.add(Nx.multiply(floated, 2.0), Nx.add(floated, 1.0))
    )
  end

  test "fails when using EXLA compiler" do
    input = Nx.tensor([1, 2, 3])

    assert_raise RuntimeError,
                 "Nx.elixir_call/3 is not supported yet. Use Nx.Defn.Evaluator as your compiler.",
                 fn -> EXLA.jit_apply(&split_and_sum/1, [input]) end
  end
end
55 changes: 55 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2196,6 +2196,61 @@ defmodule Nx do
list
end

@doc """
Invokes an Elixir function from within defn.

This function allows integrating arbitrary Elixir code into `defn` graphs.
It receives an output template (a tensor or a tuple of tensors) that
specifies the expected shapes, types, and names of the result, a list of
arguments to pass to the Elixir function, and the function itself.

Inside `defn`, this builds an expression node understood by compilers.
Outside `defn` or on backends without special support, it executes `fun`
directly and validates the result matches the template.

Raises `ArgumentError` when the arity of `fun` does not match the number
of entries in `args`, or when the result does not match the `output`
template.
"""
@doc type: :backend
def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do
  {:arity, arity} = Function.info(fun, :arity)
  num_args = length(args)

  if arity != num_args do
    raise ArgumentError,
          "expected #{arity} arguments, got #{num_args}"
  end

  # Dispatch on the backend of the tensor arguments. Backends may implement
  # the optional elixir_call/3 callback (e.g. to build a graph node);
  # otherwise the function is applied eagerly.
  backend = Nx.Shared.list_impl!(args)

  result =
    if function_exported?(backend, :elixir_call, 3) do
      backend.elixir_call(output, args, fun)
    else
      apply(fun, args)
    end

  # Either way, the result must structurally match the declared template.
  ensure_call_compatible!(result, output)
end

# Validates that the value produced by the elixir_call function matches the
# output template: tuples are compared element by element, tensors must agree
# on shape, type, and names exactly. Raises ArgumentError on any mismatch.
defp ensure_call_compatible!(left, right) when tuple_size(left) == tuple_size(right) do
  left
  |> Tuple.to_list()
  |> Enum.zip(Tuple.to_list(right))
  |> Enum.each(fn {l, r} -> ensure_call_compatible!(l, r) end)

  left
end

defp ensure_call_compatible!(
       %{shape: shape, type: type, names: names} = left,
       %{shape: shape, type: type, names: names}
     ),
     do: left

defp ensure_call_compatible!(left, right) do
  raise ArgumentError,
        "expected the elixir_call function to match the given output template #{inspect(right)}, got: #{inspect(left)}"
end

defp chunk([], data, type) do
match_types [type] do
<<match!(head, 0), tail::binary>> = data
Expand Down
10 changes: 10 additions & 0 deletions nx/lib/nx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ defmodule Nx.Backend do
"""
@callback optional(atom, [term], fun) :: tensor

@doc """
Invoked to execute a generic Elixir callback from within defn.

The backend may choose how to execute it. For example, EXLA can lower
to a custom_call that interacts with Erlang/Elixir via C; pure CPU
backends may call the function directly.

The result must match the `out` template, which is either a single
tensor or a tuple of tensors.
"""
@callback elixir_call(out :: tensor | tuple, [term], fun) :: tensor | tuple

@callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor
@callback cholesky(out :: tensor, tensor) :: tensor
@callback eigh({eigenvals :: tensor, eigenvecs :: tensor}, tensor, keyword) :: tensor
Expand All @@ -162,6 +171,7 @@ defmodule Nx.Backend do

@optional_callbacks [
optional: 3,
elixir_call: 3,
solve: 3,
determinant: 2,
logical_not: 2,
Expand Down
26 changes: 26 additions & 0 deletions nx/lib/nx/defn/evaluator.ex
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ defmodule Nx.Defn.Evaluator do
Map.put(cache, [:optional | id], optional_expr_cache)
end

# Builds the evaluation cache for the inputs of an :elixir_call node.
# args is [in_args, fun]: the fun itself needs no caching, and within
# in_args plain lists are keyword options (not expressions), so only the
# tensor expressions are recursed into.
# NOTE(review): unlike :optional above, nothing is stored under the node's
# own id here — presumably the node is evaluated inline by eval_apply;
# confirm no per-node cache entry is needed.
defp compute_cache(:elixir_call, %{data: %Expr{args: args}}, state, cache) do
[in_args, _fun] = args

Enum.reduce(in_args, cache, fn
t, cache when is_list(t) -> cache
t, cache -> compute_cache(t, state, cache)
end)
end

defp compute_cache(:cond, %{data: %Expr{args: [clauses, last], id: id}}, state, cache) do
%{parent_ids: parent_ids, current_ids: current_ids} = state

Expand Down Expand Up @@ -431,6 +440,23 @@ defmodule Nx.Defn.Evaluator do
end
end

# Evaluates an :elixir_call node by running the captured Elixir function.
# in_args holds tensor expressions followed by optional keyword lists;
# split_while assumes all non-list (tensor) arguments come first.
# NOTE(review): a list argument appearing before a tensor argument would be
# misclassified as options — the same ordering convention is assumed by
# Nx.Defn.Expr.elixir_call/3; confirm it is documented for callers.
defp eval_apply(
:elixir_call,
%{data: %Expr{args: [in_args, fun]}} = expr,
state,
caches
) do
{tensor_args, opts} = Enum.split_while(in_args, &(not is_list(&1)))
{evaluated_tensors, caches} = Enum.map_reduce(tensor_args, caches, &eval(&1, state, &2))
backend = Nx.Shared.list_impl!(evaluated_tensors)

# If the evaluated arguments are still symbolic expressions (nested defn
# tracing), keep the node unevaluated; otherwise apply the function
# eagerly with the evaluated tensors plus the original option lists.
if backend == Nx.Defn.Expr do
{expr, caches}
else
{apply(fun, evaluated_tensors ++ opts), caches}
end
end

defp eval_apply(op, %{vectorized_axes: [_ | _]} = ans, _state, _caches) do
raise "unexpected vectorized axes in evaluator for operation #{inspect(op)}: #{inspect(ans)}"
end
Expand Down
18 changes: 18 additions & 0 deletions nx/lib/nx/defn/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ defmodule Nx.Defn.Expr do

* `attach_token(token(%Nx.Defn.Token{}), expr)`

* `elixir_call(out, args, fun)`

`defn` compilers must handle said nodes accordingly.
"""

Expand Down Expand Up @@ -384,6 +386,22 @@ defmodule Nx.Defn.Expr do
end
end

@impl true
def elixir_call(out, in_args, fun) do
  # Tensor arguments come first; trailing keyword lists are passed through
  # to the function verbatim and never become part of the expression graph.
  {tensor_args, _opts} = Enum.split_while(in_args, &(not is_list(&1)))

  # At least one tensor argument is required to anchor the node in a graph:
  # the expression context is taken from the first one. Previously an empty
  # tensor_args list crashed with an opaque MatchError.
  case Enum.map(tensor_args, &to_expr/1) do
    [] ->
      raise ArgumentError,
            "Nx.elixir_call/3 expects at least one tensor argument before any option lists"

    [%T{data: %Expr{context: context}} | _] ->
      case out do
        t when is_struct(t, Nx.Tensor) ->
          # Single-tensor template: one expression node shaped like `out`.
          expr(t, context, :elixir_call, [in_args, fun])

        tuple when is_tuple(tuple) ->
          # Tuple template: one node producing a tuple, then destructured
          # into per-element expressions matching the given templates.
          out_template = tuple_out(tuple_size(tuple))
          expr_node = expr(out_template, context, :elixir_call, [in_args, fun])
          tuple(expr_node, Tuple.to_list(tuple))
      end
  end
end

## Nx.Defn AST callbacks

@doc false
Expand Down
4 changes: 4 additions & 0 deletions nx/lib/nx/defn/grad.ex
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ defmodule Nx.Defn.Grad do
acc
end

# :elixir_call nodes are opaque to autodiff: the Elixir function body is not
# an expression graph, so its inputs are not registered as parents and no
# gradient flows through the call.
# NOTE(review): differentiating through an elixir_call therefore treats it
# as a constant — confirm this is the intended semantics rather than an
# explicit "cannot differentiate" error.
defp parents_args(:elixir_call, _expr, _id, acc, _parent_vectorized_names) do
acc
end

defp parents_args(
:optional,
%{data: %{args: [call, _expr, callback]}} = t,
Expand Down
15 changes: 15 additions & 0 deletions nx/lib/nx/defn/tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,21 @@ defmodule Nx.Defn.Tree do
{[call, expr, callback], acc}
end

# Traverses the tensor arguments of an :elixir_call node with `fun`,
# leaving the callback and any keyword option lists untouched.
def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, _type, acc, fun) do
  [in_args, callback] = args

  {in_args, acc} =
    Enum.map_reduce(in_args, acc, fn
      opts, acc when is_list(opts) -> {opts, acc}
      tensor, acc -> Composite.traverse(tensor, acc, fun)
    end)

  {[in_args, callback], acc}
end

def apply_args(%T{data: %Expr{op: :token, args: [token]}}, _type, acc, fun) do
{hooks, acc} =
Enum.map_reduce(token.hooks, acc, fn %{expr: expr} = token, acc ->
Expand Down
49 changes: 49 additions & 0 deletions nx/test/nx/defn/elixir_call_evaluator_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
defmodule Nx.Defn.ElixirCallEvaluatorTest do
  use ExUnit.Case, async: true
  import Nx.Defn

  setup do
    Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)
    :ok
  end

  # Single-tensor output: template promotes the input type to floating point.
  defn add_offset(x) do
    out = %{x | type: Nx.Type.to_floating(x.type)}

    Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts ->
      Nx.add(Nx.as_type(t, :f32), opts[:offset])
    end)
  end

  test "elixir_call with single output" do
    x = Nx.iota({5})
    y = add_offset(x)

    expected = Nx.add(Nx.as_type(x, :f32), 10.0)
    assert Nx.all_close(y, expected) |> Nx.to_number() == 1
  end

  # Tuple output: template is a two-element tuple of tensor templates.
  defn split_and_sum(x) do
    fx = Nx.as_type(x, :f32)

    out0 = fx
    out1 = fx
    out_template = {out0, out1}

    {a, b} =
      Nx.elixir_call(out_template, [fx], fn t ->
        {Nx.multiply(t, 2.0), Nx.add(t, 1.0)}
      end)

    Nx.add(a, b)
  end

  test "elixir_call with tuple output" do
    x = Nx.tensor([1, 2, 3])
    y = split_and_sum(x)

    fx = Nx.as_type(x, :f32)
    expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0))
    # Compare numerically, as the first test does: bare struct equality (==)
    # on float tensors is brittle and couples the test to the backend's
    # internal representation.
    assert Nx.all_close(y, expected) |> Nx.to_number() == 1
  end
end
4 changes: 2 additions & 2 deletions torchx/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ defmodule Torchx.MixProject do

# NOTE(review): the hex dependency was swapped for a path dependency so
# torchx can compile against the unreleased Nx.elixir_call/3. Restore the
# hex requirement before cutting a release.
defp deps do
[
# {:nx, "~> 0.10.0"},
{:nx, path: "../nx"},
{:ex_doc, "~> 0.29", only: :docs}
]
end
Expand Down
51 changes: 51 additions & 0 deletions torchx/test/torchx/defn/elixir_call_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
defmodule Torchx.Defn.ElixirCallEvaluatorTest do
  use ExUnit.Case, async: true
  import Nx.Defn
  import Nx.Testing

  # Evaluator compiler + Torchx backend: elixir_call runs the Elixir fun
  # directly against Torchx tensors.
  setup do
    Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)
    Nx.default_backend(Torchx.Backend)
    :ok
  end

  # Single-tensor output template with floating-point promotion.
  defn add_offset(tensor) do
    template = %{tensor | type: Nx.Type.to_floating(tensor.type)}

    Nx.elixir_call(template, [tensor, [offset: 10.0]], fn arg, opts ->
      Nx.add(Nx.as_type(arg, :f32), opts[:offset])
    end)
  end

  test "elixir_call with single output" do
    tensor = Nx.iota({5})

    assert_equal(add_offset(tensor), Nx.add(Nx.as_type(tensor, :f32), 10.0))
  end

  # Tuple output template built from two copies of the same tensor template.
  defn split_and_sum(tensor) do
    as_float = Nx.as_type(tensor, :f32)

    {left, right} =
      Nx.elixir_call({as_float, as_float}, [as_float], fn arg ->
        {Nx.multiply(arg, 2.0), Nx.add(arg, 1.0)}
      end)

    Nx.add(left, right)
  end

  test "elixir_call with tuple output" do
    tensor = Nx.tensor([1, 2, 3])
    as_float = Nx.as_type(tensor, :f32)

    assert_equal(
      split_and_sum(tensor),
      Nx.add(Nx.multiply(as_float, 2.0), Nx.add(as_float, 1.0))
    )
  end
end