diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 413c38ce45..7a7b3865a8 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1209,6 +1209,10 @@ defmodule EXLA.Defn do EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end + defp to_operator(:elixir_call, _, _, _) do + raise "Nx.elixir_call/3 is not supported yet. Use Nx.Defn.Evaluator as your compiler." + end + defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do n = opts[:length] axis = opts[:axis] diff --git a/exla/test/exla/defn/elixir_call_test.exs b/exla/test/exla/defn/elixir_call_test.exs new file mode 100644 index 0000000000..add051a3f6 --- /dev/null +++ b/exla/test/exla/defn/elixir_call_test.exs @@ -0,0 +1,61 @@ +defmodule EXLA.Defn.ElixirCallEvaluatorTest do + use ExUnit.Case, async: true + import Nx.Defn + import Nx.Testing + + setup do + Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) + Nx.default_backend(EXLA.Backend) + :ok + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call with single output" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert_equal(y, expected) + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, [fx], fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output" do + x = Nx.tensor([1, 2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert_equal(y, expected) + end + + test "fails when using EXLA compiler" do + x = Nx.tensor([1, 2, 3]) + + assert_raise RuntimeError, + "Nx.elixir_call/3 is not supported yet. Use Nx.Defn.Evaluator as your compiler.", + fn -> + EXLA.jit_apply(&split_and_sum/1, [x]) + end + end +end diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 091372d005..715a149286 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -2196,6 +2196,61 @@ defmodule Nx do list end + @doc """ + Invokes an Elixir function from within defn. + + This function allows integrating arbitrary Elixir code into `defn` graphs. + It receives an output template (a tensor or a tuple of tensors) that + specifies the expected shapes, types, and names of the result, a list of + arguments to pass to the Elixir function, and the function itself. + + Inside `defn`, this builds an expression node understood by compilers. + Outside `defn` or on backends without special support, it executes `fun` + directly and validates the result matches the template. + """ + @doc type: :backend + def elixir_call(output, args, fun) when is_list(args) and is_function(fun) do + {:arity, arity} = Function.info(fun, :arity) + num_args = length(args) + + if arity != num_args do + raise ArgumentError, + "expected #{arity} arguments, got #{num_args}" + end + + backend = Nx.Shared.list_impl!(args) + + cond do + function_exported?(backend, :elixir_call, 3) -> + output + |> backend.elixir_call(args, fun) + |> ensure_call_compatible!(output) + + true -> + fun + |> apply(args) + |> ensure_call_compatible!(output) + end + end + + defp ensure_call_compatible!(left, right) when tuple_size(left) == tuple_size(right) do + [Tuple.to_list(left), Tuple.to_list(right)] + |> Enum.zip_with(fn [l, r] -> ensure_call_compatible!(l, r) end) + + left + end + + defp ensure_call_compatible!( + %{shape: shape, type: type, names: names} = left, + %{shape: shape, type: type, names: names} + ), + do: left + + defp ensure_call_compatible!(left, right) do + raise ArgumentError, + "expected the elixir_call function to match the given output template #{inspect(right)}, got: #{inspect(left)}" + end + defp chunk([], data, type) do match_types [type] do <> = data diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index 3c463ba237..f8556ce308 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -142,6 +142,15 @@ defmodule Nx.Backend do """ @callback optional(atom, [term], fun) :: tensor + @doc """ + Invoked to execute a generic Elixir callback from within defn. + + The backend may choose how to execute it. For example, EXLA can lower + to a custom_call that interacts with Erlang/Elixir via C; pure CPU + backends may call the function directly. + """ + @callback elixir_call(out :: tensor | tuple, [term], fun) :: tensor + @callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor @callback cholesky(out :: tensor, tensor) :: tensor @callback eigh({eigenvals :: tensor, eigenvecs :: tensor}, tensor, keyword) :: tensor @@ -162,6 +171,7 @@ defmodule Nx.Backend do @optional_callbacks [ optional: 3, + elixir_call: 3, solve: 3, determinant: 2, logical_not: 2, diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index d028ec6a63..c913f4ec3c 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -175,6 +175,15 @@ defmodule Nx.Defn.Evaluator do Map.put(cache, [:optional | id], optional_expr_cache) end + defp compute_cache(:elixir_call, %{data: %Expr{args: args}}, state, cache) do + [in_args, _fun] = args + + Enum.reduce(in_args, cache, fn + t, cache when is_list(t) -> cache + t, cache -> compute_cache(t, state, cache) + end) + end + defp compute_cache(:cond, %{data: %Expr{args: [clauses, last], id: id}}, state, cache) do %{parent_ids: parent_ids, current_ids: current_ids} = state @@ -431,6 +440,23 @@ defmodule Nx.Defn.Evaluator do end end + defp eval_apply( + :elixir_call, + %{data: %Expr{args: [in_args, fun]}} = expr, + state, + caches + ) do + {tensor_args, opts} = Enum.split_while(in_args, &(not is_list(&1))) + {evaluated_tensors, caches} = Enum.map_reduce(tensor_args, caches, &eval(&1, state, &2)) + backend = Nx.Shared.list_impl!(evaluated_tensors) + + if backend == Nx.Defn.Expr do + {expr, caches} + else + {apply(fun, evaluated_tensors ++ opts), caches} + end + end + defp eval_apply(op, %{vectorized_axes: [_ | _]} = ans, _state, _caches) do raise "unexpected vectorized axes in evaluator for operation #{inspect(op)}: #{inspect(ans)}" end diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 782e4a07fd..1d488df888 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -41,6 +41,8 @@ defmodule Nx.Defn.Expr do * `attach_token(token(%Nx.Defn.Token{}), expr)` + * `elixir_call(name, args, fun)` + `defn` compilers must handle said nodes accordingly. """ @@ -384,6 +386,22 @@ defmodule Nx.Defn.Expr do end end + @impl true + def elixir_call(out, in_args, fun) do + {tensor_args, _opts} = Enum.split_while(in_args, &(not is_list(&1))) + [%T{data: %Expr{context: context}} | _] = Enum.map(tensor_args, &to_expr/1) + + case out do + t when is_struct(t, Nx.Tensor) -> + expr(t, context, :elixir_call, [in_args, fun]) + + tuple when is_tuple(tuple) -> + out_template = tuple_out(tuple_size(tuple)) + expr_node = expr(out_template, context, :elixir_call, [in_args, fun]) + tuple(expr_node, Tuple.to_list(tuple)) + end + end + ## Nx.Defn AST callbacks @doc false diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 2941889f98..8c72d0fed0 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -122,6 +122,10 @@ defmodule Nx.Defn.Grad do acc end + defp parents_args(:elixir_call, _expr, _id, acc, _parent_vectorized_names) do + acc + end + defp parents_args( :optional, %{data: %{args: [call, _expr, callback]}} = t, diff --git a/nx/lib/nx/defn/tree.ex b/nx/lib/nx/defn/tree.ex index 582b9d4689..733a131e4f 100644 --- a/nx/lib/nx/defn/tree.ex +++ b/nx/lib/nx/defn/tree.ex @@ -192,6 +192,21 @@ defmodule Nx.Defn.Tree do {[call, expr, callback], acc} end + def apply_args(%T{data: %Expr{op: :elixir_call, args: args}}, _type, acc, fun) do + [in_args, callback] = args + + {in_args, acc} = + Enum.map_reduce(in_args, acc, fn t, acc -> + if is_list(t) do + {t, acc} + else + Composite.traverse(t, acc, fun) + end + end) + + {[in_args, callback], acc} + end + def apply_args(%T{data: %Expr{op: :token, args: [token]}}, _type, acc, fun) do {hooks, acc} = Enum.map_reduce(token.hooks, acc, fn %{expr: expr} = token, acc -> diff --git a/nx/test/nx/defn/elixir_call_evaluator_test.exs b/nx/test/nx/defn/elixir_call_evaluator_test.exs new file mode 100644 index 0000000000..92fad6b431 --- /dev/null +++ b/nx/test/nx/defn/elixir_call_evaluator_test.exs @@ -0,0 +1,49 @@ +defmodule Nx.Defn.ElixirCallEvaluatorTest do + use ExUnit.Case, async: true + import Nx.Defn + + setup do + Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) + :ok + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call with single output" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert Nx.all_close(y, expected) |> Nx.to_number() == 1 + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, [fx], fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output" do + x = Nx.tensor([1, 2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert expected == y + end +end diff --git a/torchx/mix.exs b/torchx/mix.exs index fa5531e541..5174e59cdb 100644 --- a/torchx/mix.exs +++ b/torchx/mix.exs @@ -41,8 +41,8 @@ defmodule Torchx.MixProject do defp deps do [ - {:nx, "~> 0.10.0"}, - # {:nx, path: "../nx"}, + # {:nx, "~> 0.10.0"}, + {:nx, path: "../nx"}, {:ex_doc, "~> 0.29", only: :docs} ] end diff --git a/torchx/test/torchx/defn/elixir_call_test.exs b/torchx/test/torchx/defn/elixir_call_test.exs new file mode 100644 index 0000000000..9c504fa6c8 --- /dev/null +++ b/torchx/test/torchx/defn/elixir_call_test.exs @@ -0,0 +1,51 @@ +defmodule Torchx.Defn.ElixirCallEvaluatorTest do + use ExUnit.Case, async: true + import Nx.Defn + import Nx.Testing + + setup do + Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) + Nx.default_backend(Torchx.Backend) + :ok + end + + defn add_offset(x) do + out = %{x | type: Nx.Type.to_floating(x.type)} + + Nx.elixir_call(out, [x, [offset: 10.0]], fn t, opts -> + Nx.add(Nx.as_type(t, :f32), opts[:offset]) + end) + end + + test "elixir_call with single output" do + x = Nx.iota({5}) + y = add_offset(x) + + expected = Nx.add(Nx.as_type(x, :f32), 10.0) + assert_equal(y, expected) + end + + defn split_and_sum(x) do + fx = Nx.as_type(x, :f32) + + out0 = fx + out1 = fx + out_template = {out0, out1} + + {a, b} = + Nx.elixir_call(out_template, [fx], fn t -> + {Nx.multiply(t, 2.0), Nx.add(t, 1.0)} + end) + + Nx.add(a, b) + end + + test "elixir_call with tuple output" do + x = Nx.tensor([1, 2, 3]) + y = split_and_sum(x) + + fx = Nx.as_type(x, :f32) + expected = Nx.add(Nx.multiply(fx, 2.0), Nx.add(fx, 1.0)) + assert_equal(y, expected) + end +end