Skip to content

Commit f48dcd1

Browse files
authored
Support blocks with multiple inputs (#574)
1 parent b6ab577 commit f48dcd1

File tree

3 files changed

+69
-12
lines changed

3 files changed

+69
-12
lines changed

lib/axon.ex

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,17 +746,26 @@ defmodule Axon do
746746
"""
747747
@doc type: :special
748748
def block(fun, opts \\ []) when is_function(fun) do
749+
{:arity, arity} = Function.info(fun, :arity)
749750
opts = Keyword.validate!(opts, [:name, :meta])
750751
block_id = System.unique_integer([:positive, :monotonic])
751752

752-
fn inputs ->
753+
block_fun(arity, fn inputs ->
753754
layer(:block, List.wrap(inputs),
754755
op_name: :block,
755756
name: opts[:name],
756757
meta: opts[:meta],
757758
block_fun: fun,
758759
block_id: block_id
759760
)
761+
end)
762+
end
763+
764+
for i <- 0..128 do
765+
args = Macro.generate_arguments(i, __MODULE__)
766+
767+
defp block_fun(unquote(i), callback) do
768+
fn unquote_splicing(args) -> callback.(unquote(args)) end
760769
end
761770
end
762771

lib/axon/compiler.ex

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -606,17 +606,17 @@ defmodule Axon.Compiler do
606606
%Axon.Node{
607607
id: id,
608608
op: :block,
609-
parent: [parent],
609+
parent: parents,
610610
opts: [block_fun: block_fun, block_id: block_id],
611611
name: name_fn
612612
},
613613
nodes,
614614
cache_and_counts,
615615
config
616616
) do
617-
{[parent_id], {cache, op_counts, block_cache, model_state_meta}} =
617+
{parent_ids, {cache, op_counts, block_cache, model_state_meta}} =
618618
Enum.map_reduce(
619-
[parent],
619+
parents,
620620
cache_and_counts,
621621
&to_model_funs(&1, nodes, &2, config)
622622
)
@@ -627,7 +627,8 @@ defmodule Axon.Compiler do
627627
{funs, name, block_cache, op_counts}
628628

629629
%{} ->
630-
funs = build(block_fun.(Axon.input("subgraph")), debug?: config.debug?)
630+
inputs = Enum.with_index(parents, fn _, i -> Axon.input("subgraph#{i}") end)
631+
funs = build(apply(block_fun, inputs), debug?: config.debug?)
631632
name = name_fn.(:block, op_counts)
632633
op_counts = Map.update(op_counts, :block, 1, fn x -> x + 1 end)
633634
{funs, name, Map.put(block_cache, block_id, {funs, name}), op_counts}
@@ -637,9 +638,9 @@ defmodule Axon.Compiler do
637638
# Recurse graph inputs and invoke cache to get parent results,
638639
# state, and result_cache and then apply dtype policy and hooks
639640
# to each input
640-
{[layer_input], {state, result_cache, none?}} =
641+
{layer_inputs, {state, result_cache, none?}} =
641642
Enum.map_reduce(
642-
[parent_id],
643+
parent_ids,
643644
{state, result_cache, false},
644645
fn parent_id, {state, result_cache, none?} ->
645646
{layer_input, {state, result_cache}} =
@@ -663,7 +664,13 @@ defmodule Axon.Compiler do
663664
{%Axon.None{}, {state, result_cache}}
664665
else
665666
block_params = params[block_name] || %{}
666-
result = apply(block_predict_fun, [Axon.ModelState.new(block_params), layer_input])
667+
668+
inputs =
669+
layer_inputs
670+
|> Enum.with_index()
671+
|> Map.new(fn {input, i} -> {"subgraph#{i}", input} end)
672+
673+
result = apply(block_predict_fun, [Axon.ModelState.new(block_params), inputs])
667674

668675
{out_result, out_state} =
669676
case result do
@@ -685,8 +692,8 @@ defmodule Axon.Compiler do
685692
end
686693

687694
init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
688-
{[parent_shape], {parent_params, result_cache, none?}} =
689-
Enum.map_reduce([parent_id], {%{}, result_cache, false}, fn
695+
{parent_shapes, {parent_params, result_cache, none?}} =
696+
Enum.map_reduce(parent_ids, {%{}, result_cache, false}, fn
690697
parent_id, {params, result_cache, none?} ->
691698
{parent_shape, {params, result_cache}} =
692699
call_init_cache(
@@ -706,8 +713,12 @@ defmodule Axon.Compiler do
706713
if none? do
707714
{%Axon.None{}, {parent_params, result_cache}}
708715
else
709-
template = Nx.broadcast(0.0, parent_shape)
710-
block_params = apply(block_init_fun, [template, Axon.ModelState.empty()])
716+
templates =
717+
parent_shapes
718+
|> Enum.with_index()
719+
|> Map.new(fn {shape, i} -> {"subgraph#{i}", Nx.broadcast(0.0, shape)} end)
720+
721+
block_params = apply(block_init_fun, [templates, Axon.ModelState.empty()])
711722

712723
params =
713724
if block_params == %{} do

test/axon/compiler_test.exs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5330,6 +5330,43 @@ defmodule CompilerTest do
53305330
input = random({1, 1})
53315331
assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k, b))
53325332
end
5333+
5334+
test "works with multiple block inputs" do
5335+
block =
5336+
Axon.block(fn x, y ->
5337+
dense = Axon.block(&Axon.dense(&1, 4))
5338+
Axon.add(dense.(y), dense.(x))
5339+
end)
5340+
5341+
input1 = Axon.input("input1")
5342+
input2 = Axon.input("input2")
5343+
5344+
model = block.(input1, input2) |> Axon.dense(1)
5345+
5346+
{init_fn, predict_fn} = Axon.build(model)
5347+
5348+
actual_predict_fn = fn %{"input1" => x, "input2" => y}, k1, b1, k2, b2 ->
5349+
x = Axon.Layers.dense(x, k1, b1)
5350+
y = Axon.Layers.dense(y, k1, b1)
5351+
5352+
x
5353+
|> Nx.add(y)
5354+
|> Axon.Layers.dense(k2, b2)
5355+
end
5356+
5357+
input = %{"input1" => Nx.tensor([[0.5]]), "input2" => Nx.tensor([[0.75]])}
5358+
5359+
assert %ModelState{
5360+
data: %{
5361+
"block_0" => %{
5362+
"block_0" => %{"dense_0" => %{"kernel" => k1, "bias" => b1}}
5363+
},
5364+
"dense_0" => %{"kernel" => k2, "bias" => b2}
5365+
}
5366+
} = params = init_fn.(input, ModelState.empty())
5367+
5368+
assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k1, b1, k2, b2))
5369+
end
53335370
end
53345371

53355372
describe "initializers" do

0 commit comments

Comments
 (0)