@@ -279,6 +279,8 @@ defmodule Axon do
279
279
alias __MODULE__ , as: Axon
280
280
alias Axon.Parameter
281
281
282
+ import Axon.Shared
283
+
282
284
require Logger
283
285
284
286
@ type t :: % __MODULE__ { }
@@ -380,15 +382,6 @@ defmodule Axon do
380
382
}
381
383
end
382
384
383
- defp split_inputs ( :container , [ inputs ] ) do
384
- { inputs , cache } =
385
- deep_map_reduce ( inputs , % { } , fn % Axon { output: id , nodes: nodes } , cache ->
386
- { id , Map . merge ( nodes , cache ) }
387
- end )
388
-
389
- { [ inputs ] , [ ] , [ :layer ] , cache }
390
- end
391
-
392
385
defp split_inputs ( _op , inputs ) do
393
386
Enum . reduce ( inputs , { [ ] , [ ] , [ ] , % { } } , fn
394
387
% Axon { output: layer_input , nodes: nodes } , { layers , params , args , cache } ->
@@ -704,62 +697,47 @@ defmodule Axon do
704
697
@ doc type: :special
705
698
def container ( container , opts \\ [ ] ) do
706
699
opts = Keyword . validate! ( opts , [ :name , :meta ] )
707
-
708
- layer ( :container , [ container ] , name: opts [ :name ] , meta: opts [ :meta ] , op_name: :container )
700
+ { structure_fn , nodes } = destructure ( container )
701
+ layer ( structure_fn , nodes , name: opts [ :name ] , meta: opts [ :meta ] , op_name: :container )
709
702
end
710
703
711
- # TODO: This should not be duplicated
712
- defp deep_new ( % Nx.Tensor { } = x , fun ) , do: fun . ( x )
713
-
714
- defp deep_new ( x , fun ) when is_number ( x ) , do: fun . ( x )
715
-
716
- defp deep_new ( map , fun ) do
717
- { cont , :ok } = Nx.Container . traverse ( map , :ok , & recur_traverse ( & 1 , & 2 , fun ) )
718
- cont
704
+ defp destructure ( container ) do
705
+ { structure , { nodes , _ } } = recur_destructure ( container , { [ ] , 0 } )
706
+ fun = restructure ( length ( nodes ) + 1 , structure )
707
+ { fun , Enum . reverse ( nodes ) }
719
708
end
720
709
721
- defp recur_traverse ( item , :ok , fun ) do
722
- case item do
723
- % Axon { } = t ->
724
- { fun . ( t ) , :ok }
725
-
726
- % { axon: :axon } = t ->
727
- { fun . ( t ) , :ok }
710
+ defp recur_destructure ( container , acc ) do
711
+ Nx.Container . traverse ( container , acc , fn value , { leaves , idx } ->
712
+ case value do
713
+ % Axon { } = leaf ->
714
+ { idx , { [ leaf | leaves ] , idx + 1 } }
728
715
729
- container ->
730
- { deep_new ( container , fun ) , :ok }
731
- end
716
+ container ->
717
+ recur_destructure ( container , { leaves , idx } )
718
+ end
719
+ end )
732
720
end
733
721
734
- defp deep_merge ( left , right , fun ) do
735
- case Nx.Container . traverse ( left , leaves ( right ) , & recur_merge ( & 1 , & 2 , fun ) ) do
736
- { merged , [ ] } ->
737
- merged
722
+ for i <- 0 .. 128 do
723
+ args = Macro . generate_arguments ( i , __MODULE__ )
738
724
739
- { _merged , _leftover } ->
740
- raise ArgumentError ,
741
- "unable to merge arguments with incompatible" <>
742
- " structure"
725
+ defp restructure ( unquote ( i ) , structure ) do
726
+ fn unquote_splicing ( args ) ->
727
+ args_tuple = { unquote_splicing ( args ) }
728
+ { container , :ok } = recur_restructure ( structure , args_tuple )
729
+ container
730
+ end
743
731
end
744
732
end
745
733
746
- defp leaves ( container ) do
747
- container
748
- |> Nx.Container . reduce ( [ ] , fn x , acc -> [ x | acc ] end )
749
- |> Enum . reverse ( )
750
- end
751
-
752
- defp recur_merge ( left , [ right | right_leaves ] , fun ) do
753
- case { left , right } do
754
- { % Nx.Tensor { } = left , % Nx.Tensor { } = right } ->
755
- { fun . ( left , right ) , right_leaves }
756
-
757
- { % Axon { } = left , % Axon { } = right } ->
758
- { fun . ( left , right ) , right_leaves }
759
-
760
- { left , right } ->
761
- { deep_merge ( left , right , fun ) , right_leaves }
762
- end
734
+ defp recur_restructure ( structure , args_tuple ) do
735
+ Nx.Container . traverse ( structure , :ok , fn value , :ok ->
736
+ case value do
737
+ idx when is_integer ( idx ) -> { elem ( args_tuple , idx ) , :ok }
738
+ container -> recur_restructure ( container , args_tuple )
739
+ end
740
+ end )
763
741
end
764
742
765
743
@ doc """
@@ -3644,35 +3622,31 @@ defmodule Axon do
3644
3622
end
3645
3623
3646
3624
@ doc """
3647
- Returns a model's output shape from the given input
3625
+ Returns a model's output template from the given input
3648
3626
template.
3627
+
3628
+ The output template gives you access to the output shape
3629
+ and type of the given input graph.
3649
3630
"""
3650
3631
@ doc type: :graph
3651
3632
def get_output_shape ( % Axon { } = axon , inputs , opts \\ [ ] ) do
3652
3633
{ init_fn , forward_fn } = build ( axon , opts ++ [ raise_on_none: false ] )
3653
3634
3654
- out =
3635
+ inputs =
3636
+ case inputs do
3637
+ % Nx.Tensor { } = input -> Nx . to_template ( input )
3638
+ inputs when is_map ( inputs ) -> Map . new ( inputs , fn { k , v } -> { k , Nx . to_template ( v ) } end )
3639
+ end
3640
+
3641
+ fun =
3655
3642
Nx.Defn . jit (
3656
3643
fn inputs ->
3657
3644
forward_fn . ( init_fn . ( inputs , Axon.ModelState . empty ( ) ) , inputs )
3658
3645
end ,
3659
3646
compiler: Axon.Defn
3660
- ) . ( inputs )
3661
-
3662
- safe_shape ( out )
3663
- end
3664
-
3665
- defp safe_shape ( container_or_tensor ) do
3666
- case container_or_tensor do
3667
- % Axon.None { } = none ->
3668
- none
3669
-
3670
- % Nx.Tensor { } = tensor ->
3671
- Nx . shape ( tensor )
3647
+ )
3672
3648
3673
- container ->
3674
- deep_new ( container , & Nx . shape / 1 )
3675
- end
3649
+ deep_new ( apply ( fun , [ inputs ] ) , & Nx . to_template / 1 )
3676
3650
end
3677
3651
3678
3652
@ doc """
@@ -3850,74 +3824,17 @@ defmodule Axon do
3850
3824
if MapSet . member? ( visited , id ) do
3851
3825
{ acc , visited }
3852
3826
else
3853
- % { op: op , parent: parents } = parent = nodes [ id ]
3827
+ % { parent: parents } = parent = nodes [ id ]
3854
3828
3855
3829
{ acc , visited } =
3856
- case op do
3857
- :container ->
3858
- [ container ] = parents
3859
-
3860
- deep_reduce ( container , { acc , visited } , fn pid , { acc , visited } ->
3861
- traverse_nodes ( pid , nodes , acc , visited )
3862
- end )
3863
-
3864
- _ ->
3865
- Enum . reduce ( parents , { acc , visited } , fn pid , { acc , visited } ->
3866
- traverse_nodes ( pid , nodes , acc , visited )
3867
- end )
3868
- end
3830
+ Enum . reduce ( parents , { acc , visited } , fn pid , { acc , visited } ->
3831
+ traverse_nodes ( pid , nodes , acc , visited )
3832
+ end )
3869
3833
3870
3834
{ [ parent | acc ] , MapSet . put ( visited , id ) }
3871
3835
end
3872
3836
end
3873
3837
3874
- # TODO: Do not duplicate
3875
- defp deep_reduce ( item , acc , fun ) when is_integer ( item ) do
3876
- fun . ( item , acc )
3877
- end
3878
-
3879
- defp deep_reduce ( map , acc , fun ) do
3880
- Nx.Container . reduce ( map , acc , & recur_deep_reduce ( & 1 , & 2 , fun ) )
3881
- end
3882
-
3883
- defp recur_deep_reduce ( value , acc , fun ) do
3884
- case value do
3885
- % Axon { } = val ->
3886
- fun . ( val , acc )
3887
-
3888
- % Nx.Tensor { } = val ->
3889
- fun . ( val , acc )
3890
-
3891
- % { axon: :axon } = val ->
3892
- fun . ( val , acc )
3893
-
3894
- val when is_integer ( val ) ->
3895
- fun . ( val , acc )
3896
-
3897
- val ->
3898
- deep_reduce ( val , acc , fun )
3899
- end
3900
- end
3901
-
3902
- defp deep_map_reduce ( leaf , acc , fun ) when is_integer ( leaf ) , do: fun . ( leaf , acc )
3903
-
3904
- defp deep_map_reduce ( container , acc , fun ) do
3905
- Nx.Container . traverse ( container , acc , & recur_deep_map_reduce ( & 1 , & 2 , fun ) )
3906
- end
3907
-
3908
- defp recur_deep_map_reduce ( leaf , acc , fun ) do
3909
- case leaf do
3910
- % Axon { } = leaf ->
3911
- fun . ( leaf , acc )
3912
-
3913
- % Nx.Tensor { } = leaf ->
3914
- fun . ( leaf , acc )
3915
-
3916
- container ->
3917
- deep_map_reduce ( container , acc , fun )
3918
- end
3919
- end
3920
-
3921
3838
@ doc """
3922
3839
Pops the top node off of the graph.
3923
3840
0 commit comments