@@ -606,17 +606,17 @@ defmodule Axon.Compiler do
606
606
% Axon.Node {
607
607
id: id ,
608
608
op: :block ,
609
- parent: [ parent ] ,
609
+ parent: parents ,
610
610
opts: [ block_fun: block_fun , block_id: block_id ] ,
611
611
name: name_fn
612
612
} ,
613
613
nodes ,
614
614
cache_and_counts ,
615
615
config
616
616
) do
617
- { [ parent_id ] , { cache , op_counts , block_cache , model_state_meta } } =
617
+ { parent_ids , { cache , op_counts , block_cache , model_state_meta } } =
618
618
Enum . map_reduce (
619
- [ parent ] ,
619
+ parents ,
620
620
cache_and_counts ,
621
621
& to_model_funs ( & 1 , nodes , & 2 , config )
622
622
)
@@ -627,7 +627,8 @@ defmodule Axon.Compiler do
627
627
{ funs , name , block_cache , op_counts }
628
628
629
629
% { } ->
630
- funs = build ( block_fun . ( Axon . input ( "subgraph" ) ) , debug?: config . debug? )
630
+ inputs = Enum . with_index ( parents , fn _ , i -> Axon . input ( "subgraph#{ i } " ) end )
631
+ funs = build ( apply ( block_fun , inputs ) , debug?: config . debug? )
631
632
name = name_fn . ( :block , op_counts )
632
633
op_counts = Map . update ( op_counts , :block , 1 , fn x -> x + 1 end )
633
634
{ funs , name , Map . put ( block_cache , block_id , { funs , name } ) , op_counts }
@@ -637,9 +638,9 @@ defmodule Axon.Compiler do
637
638
# Recurse graph inputs and invoke cache to get parent results,
638
639
# state, and result_cache and then apply dtype policy and hooks
639
640
# to each input
640
- { [ layer_input ] , { state , result_cache , none? } } =
641
+ { layer_inputs , { state , result_cache , none? } } =
641
642
Enum . map_reduce (
642
- [ parent_id ] ,
643
+ parent_ids ,
643
644
{ state , result_cache , false } ,
644
645
fn parent_id , { state , result_cache , none? } ->
645
646
{ layer_input , { state , result_cache } } =
@@ -663,7 +664,13 @@ defmodule Axon.Compiler do
663
664
{ % Axon.None { } , { state , result_cache } }
664
665
else
665
666
block_params = params [ block_name ] || % { }
666
- result = apply ( block_predict_fun , [ Axon.ModelState . new ( block_params ) , layer_input ] )
667
+
668
+ inputs =
669
+ layer_inputs
670
+ |> Enum . with_index ( )
671
+ |> Map . new ( fn { input , i } -> { "subgraph#{ i } " , input } end )
672
+
673
+ result = apply ( block_predict_fun , [ Axon.ModelState . new ( block_params ) , inputs ] )
667
674
668
675
{ out_result , out_state } =
669
676
case result do
@@ -685,8 +692,8 @@ defmodule Axon.Compiler do
685
692
end
686
693
687
694
init_fun = fn template , cache , result_cache , fn_stacktrace , keys ->
688
- { [ parent_shape ] , { parent_params , result_cache , none? } } =
689
- Enum . map_reduce ( [ parent_id ] , { % { } , result_cache , false } , fn
695
+ { parent_shapes , { parent_params , result_cache , none? } } =
696
+ Enum . map_reduce ( parent_ids , { % { } , result_cache , false } , fn
690
697
parent_id , { params , result_cache , none? } ->
691
698
{ parent_shape , { params , result_cache } } =
692
699
call_init_cache (
@@ -706,8 +713,12 @@ defmodule Axon.Compiler do
706
713
if none? do
707
714
{ % Axon.None { } , { parent_params , result_cache } }
708
715
else
709
- template = Nx . broadcast ( 0.0 , parent_shape )
710
- block_params = apply ( block_init_fun , [ template , Axon.ModelState . empty ( ) ] )
716
+ templates =
717
+ parent_shapes
718
+ |> Enum . with_index ( )
719
+ |> Map . new ( fn { shape , i } -> { "subgraph#{ i } " , Nx . broadcast ( 0.0 , shape ) } end )
720
+
721
+ block_params = apply ( block_init_fun , [ templates , Axon.ModelState . empty ( ) ] )
711
722
712
723
params =
713
724
if block_params == % { } do
0 commit comments