Skip to content

Commit 44313f6

Browse files
committed
Allow dynamic calls in defn
1 parent 83419bb commit 44313f6

File tree

2 files changed

+15
-23
lines changed

2 files changed

+15
-23
lines changed

nx/lib/nx/defn/compiler.ex

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -568,19 +568,6 @@ defmodule Nx.Defn.Compiler do
568568
{{{:., dot_meta, [fun]}, meta, args}, state}
569569
end
570570

571-
# TODO: Remove me once transform/2 is removed.
572-
defp normalize({{:., _, [Nx.Defn.Kernel, :transform]} = call, meta, [ast, fun]}, state) do
573-
{ast, state} = normalize(ast, state)
574-
575-
fun =
576-
Macro.prewalk(fun, fn
577-
var when is_var(var) -> normalize_var(var)
578-
node -> node
579-
end)
580-
581-
{{call, meta, [ast, fun]}, state}
582-
end
583-
584571
defp normalize({{:., _, [Nx.Defn.Kernel, :hook]} = call, meta, [ast | rest]}, state) do
585572
{ast, state} = normalize(ast, state)
586573
{{call, meta, [ast | rest]}, state}
@@ -647,11 +634,10 @@ defmodule Nx.Defn.Compiler do
647634
state}
648635
end
649636

650-
defp normalize({{:., dot_meta, [remote, name]}, meta, args}, state)
651-
# TODO: Remove args == [] once we require Elixir version where args are nil
652-
when is_atom(name) and (args == nil or args == []) do
637+
defp normalize({{:., dot_meta, [remote, name]}, meta, args}, state) when is_atom(name) do
653638
{remote, state} = normalize(remote, state)
654-
{{{:., dot_meta, [Map, :fetch!]}, meta, [remote, name]}, state}
639+
{args, state} = normalize_list(args, state)
640+
{{{:., dot_meta, [remote, name]}, meta, args}, state}
655641
end
656642

657643
defp normalize({left, right}, state) do

nx/test/nx/defn_test.exs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -953,34 +953,40 @@ defmodule Nx.DefnTest do
953953

954954
describe "remote functions" do
955955
defmodule Remote do
956-
defn(add_two(c, d), do: c + d)
956+
defn add_two(c, d), do: c + d
957957
end
958958

959-
defn(add_two_remote(a, b), do: Remote.add_two(a, b))
959+
defn add_two_remote(a, b), do: Remote.add_two(a, b)
960960

961961
test "public" do
962962
assert %T{data: %Expr{op: :add, args: [_, _]}} = add_two_remote(1, 2)
963963
end
964964

965-
defn(add_two_unknown(a, b), do: Nx.DefnTest.unknown(a, b))
965+
defn add_two_dynamic(a, b, opts \\ []), do: opts[:remote].add_two(a, b)
966966

967-
def not_defn(a, b), do: Nx.add(a, b)
968-
defn(add_two_not_defn(a, b), do: Nx.DefnTest.not_defn(a, b))
967+
test "dynamic" do
968+
assert %T{data: %Expr{op: :add, args: [_, _]}} = add_two_remote(1, 2)
969+
end
969970

970-
defn(add_two_io(a, b), do: IO.inspect({a, b}))
971+
defn add_two_unknown(a, b), do: Nx.DefnTest.unknown(a, b)
971972

972973
test "undefined remote" do
973974
assert_raise UndefinedFunctionError,
974975
"function Nx.DefnTest.unknown/2 is undefined or private",
975976
fn -> add_two_unknown(1, 2) end
976977
end
977978

979+
def not_defn(a, b), do: Nx.add(a, b)
980+
defn add_two_not_defn(a, b), do: Nx.DefnTest.not_defn(a, b)
981+
978982
test "not defn remote" do
979983
assert_raise RuntimeError,
980984
"cannot invoke Nx.DefnTest.not_defn/2 inside defn because it was not defined with defn",
981985
fn -> add_two_not_defn(1, 2) end
982986
end
983987

988+
defn add_two_io(a, b), do: IO.inspect({a, b})
989+
984990
test "IO remote" do
985991
assert_raise RuntimeError,
986992
"cannot invoke IO.inspect/1 inside defn because it was not defined with defn. " <>

0 commit comments

Comments
 (0)