@@ -281,9 +281,6 @@ defmodule Axon do
281
281
282
282
require Logger
283
283
284
- # Axon serialization version
285
- @ file_version 1
286
-
287
284
@ type t :: % __MODULE__ { }
288
285
289
286
defstruct [
@@ -417,16 +414,18 @@ defmodule Axon do
417
414
}
418
415
end
419
416
420
- def param ( name , shape , opts ) when is_tuple ( shape ) or is_function ( shape ) do
421
- opts = Keyword . validate! ( opts , initializer: :glorot_uniform , type: { :f , 32 } )
417
+ def param ( name , shape , opts ) when is_binary ( name ) and ( is_tuple ( shape ) or is_function ( shape ) ) do
418
+ opts = Keyword . validate! ( opts , initializer: :glorot_uniform , type: { :f , 32 } , kind: :parameter )
422
419
initializer = validate_initializer! ( opts [ :initializer ] )
423
420
type = opts [ :type ] || { :f , 32 }
421
+ kind = opts [ :kind ] || :parameter
424
422
425
423
% Axon.Parameter {
426
424
name: name ,
427
425
shape: shape ,
428
426
type: type ,
429
- initializer: initializer
427
+ initializer: initializer ,
428
+ kind: kind
430
429
}
431
430
end
432
431
@@ -586,7 +585,7 @@ defmodule Axon do
586
585
iex> inp1 = Axon.input("input_0", shape: {nil, 1})
587
586
iex> inp2 = Axon.input("input_1", shape: {nil, 2})
588
587
iex> model = Axon.container(%{a: inp1, b: inp2})
589
- iex> %{a: a, b: b} = Axon.predict(model, %{} , %{
588
+ iex> %{a: a, b: b} = Axon.predict(model, Axon.ModelState.empty() , %{
590
589
...> "input_0" => Nx.tensor([[1.0]]),
591
590
...> "input_1" => Nx.tensor([[1.0, 2.0]])
592
591
...> })
@@ -667,42 +666,6 @@ defmodule Axon do
667
666
end
668
667
end
669
668
670
- @ doc """
671
- Wraps an Axon model into a namespace.
672
-
673
- A namespace is a part of an Axon model which is meant to
674
- be a self-contained collection of Axon layers. Namespaces
675
- are guaranteed to always generate with the same internal
676
- layer names and can be re-used universally across models.
677
-
678
- Namespaces are most useful for containing large collections
679
- of layers and offering a straightforward means for accessing
680
- the parameters of individual model components. A common application
681
- of namespaces is to use them in with a pre-trained model for
682
- fine-tuning:
683
-
684
- {base, resnet_params} = resnet()
685
- base = base |> Axon.namespace("resnet")
686
-
687
- model = base |> Axon.dense(1)
688
- {init_fn, predict_fn} = Axon.build(model)
689
-
690
- init_fn.(Nx.template({1, 3, 224, 224}, {:f, 32}), %{"resnset" => resnet_params})
691
-
692
- Notice you can use `init_fn` in conjunction with namespaces
693
- to specify which portion of a model you'd like to initialize
694
- from a fixed starting point.
695
-
696
- Namespaces have fixed names, which means it's easy to run into namespace
697
- collisions. Re-using namespaces, re-using inner parts of a namespace,
698
- and attempting to share layers between namespaces are still sharp
699
- edges in namespace usage.
700
- """
701
- @ doc type: :special
702
- def namespace ( % Axon { } = axon , name ) when is_binary ( name ) do
703
- layer ( :namespace , [ axon ] , name: name )
704
- end
705
-
706
669
@ doc """
707
670
Returns a function which represents a self-contained re-usable block
708
671
of operations in a neural network. All parameters in the block are
@@ -1561,7 +1524,8 @@ defmodule Axon do
1561
1524
key_state =
1562
1525
param ( "key" , fn _ -> { 2 } end ,
1563
1526
type: { :u , 32 } ,
1564
- initializer: fn _ , _ -> Nx.Random . key ( seed ) end
1527
+ initializer: fn _ , _ -> Nx.Random . key ( seed ) end ,
1528
+ kind: :state
1565
1529
)
1566
1530
1567
1531
layer ( dropout , [ x , key_state ] ,
@@ -1867,8 +1831,8 @@ defmodule Axon do
1867
1831
gamma = param ( "gamma" , gamma_shape , initializer: opts [ :gamma_initializer ] )
1868
1832
beta = param ( "beta" , beta_shape , initializer: opts [ :beta_initializer ] )
1869
1833
1870
- mean = param ( "mean" , mean_shape , initializer: :zeros )
1871
- var = param ( "var" , var_shape , initializer: :ones )
1834
+ mean = param ( "mean" , mean_shape , initializer: :zeros , kind: :state )
1835
+ var = param ( "var" , var_shape , initializer: :ones , kind: :state )
1872
1836
1873
1837
layer (
1874
1838
norm ,
@@ -3006,14 +2970,16 @@ defmodule Axon do
3006
2970
key_state =
3007
2971
param ( "key" , fn _ -> { 2 } end ,
3008
2972
type: { :u , 32 } ,
3009
- initializer: fn _ , _ -> Nx.Random . key ( seed ) end
2973
+ initializer: fn _ , _ -> Nx.Random . key ( seed ) end ,
2974
+ kind: :state
3010
2975
)
3011
2976
3012
2977
name =
3013
2978
case parent_name do
3014
2979
nil ->
3015
2980
fn _ , op_counts ->
3016
- "lstm_#{ op_counts [ rnn_type ] } _#{ state_name } _hidden_state"
2981
+ count = op_counts [ rnn_type ] || 0
2982
+ "#{ Atom . to_string ( rnn_type ) } _#{ count } _#{ state_name } _hidden_state"
3017
2983
end
3018
2984
3019
2985
parent_name when is_binary ( parent_name ) ->
@@ -3042,9 +3008,16 @@ defmodule Axon do
3042
3008
3043
3009
arity == 3 ->
3044
3010
fun =
3045
- fn inputs , key , _opts ->
3011
+ fn inputs , key , opts ->
3046
3012
shape = Axon.Shape . rnn_hidden_state ( Nx . shape ( inputs ) , units , rnn_type )
3047
- initializer . ( shape , { :f , 32 } , key )
3013
+ keys = Nx.Random . split ( key )
3014
+ out = initializer . ( shape , { :f , 32 } , keys [ 1 ] )
3015
+
3016
+ if opts [ :mode ] == :train do
3017
+ % Axon.StatefulOutput { output: out , state: % { "key" => keys [ 0 ] } }
3018
+ else
3019
+ out
3020
+ end
3048
3021
end
3049
3022
3050
3023
{ fun , [ x , key_state ] }
@@ -3168,6 +3141,7 @@ defmodule Axon do
3168
3141
the update process.
3169
3142
"""
3170
3143
@ doc type: :model
3144
+ @ deprecated "Use Axon.ModelState.freeze/2 instead"
3171
3145
def freeze ( model , fun_or_predicate \\ :all ) do
3172
3146
freeze ( model , fun_or_predicate , true )
3173
3147
end
@@ -3240,6 +3214,7 @@ defmodule Axon do
3240
3214
the update process.
3241
3215
"""
3242
3216
@ doc type: :model
3217
+ @ deprecated "Use Axon.ModelState.freeze/2 instead"
3243
3218
def unfreeze ( model , fun_or_predicate \\ :all ) do
3244
3219
freeze ( model , fun_or_predicate , false )
3245
3220
end
@@ -3410,7 +3385,7 @@ defmodule Axon do
3410
3385
out =
3411
3386
Nx.Defn . jit (
3412
3387
fn inputs ->
3413
- forward_fn . ( init_fn . ( inputs , % { } ) , inputs )
3388
+ forward_fn . ( init_fn . ( inputs , Axon.ModelState . empty ( ) ) , inputs )
3414
3389
end ,
3415
3390
compiler: Axon.Defn
3416
3391
) . ( inputs )
@@ -3864,158 +3839,6 @@ defmodule Axon do
3864
3839
end
3865
3840
end
3866
3841
3867
- # Serialization
3868
-
3869
- @ doc """
3870
- Serializes a model and its parameters for persisting
3871
- models to disk or elsewhere.
3872
-
3873
- Model and parameters are serialized as a tuple, where the
3874
- model is converted to a recursive map to ensure compatibility
3875
- with future Axon versions and the parameters are serialized
3876
- using `Nx.serialize/2`. There is some additional metadata included
3877
- such as current serialization version for compatibility.
3878
-
3879
- Serialization `opts` are forwarded to `Nx.serialize/2` and
3880
- `:erlang.term_to_binary/2` for controlling compression options.
3881
-
3882
- ## Examples
3883
-
3884
- iex> model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, kernel_initializer: :zeros, activation: :relu)
3885
- iex> {init_fn, _} = Axon.build(model)
3886
- iex> params = init_fn.(Nx.template({1, 2}, :f32), %{})
3887
- iex> serialized = Axon.serialize(model, params)
3888
- iex> {saved_model, saved_params} = Axon.deserialize(serialized)
3889
- iex> {_, predict_fn} = Axon.build(saved_model)
3890
- iex> predict_fn.(saved_params, Nx.tensor([[1.0, 1.0]]))
3891
- #Nx.Tensor<
3892
- f32[1][1]
3893
- [
3894
- [0.0]
3895
- ]
3896
- >
3897
-
3898
- """
3899
- @ doc type: :model
3900
- def serialize ( % Axon { output: id , nodes: nodes } , params , opts \\ [ ] ) do
3901
- Logger . warning (
3902
- "Attempting to serialize an Axon model. Serialization is discouraged" <>
3903
- " and will be deprecated, then removed in future releases. You should" <>
3904
- " keep your model definitions as code and serialize your parameters using" <>
3905
- " `Nx.serialize/2`."
3906
- )
3907
-
3908
- nodes =
3909
- Map . new ( nodes , fn { k , % { op: op , op_name: op_name } = v } ->
3910
- validate_serialized_op! ( op_name , op )
3911
- node_meta = Map . from_struct ( v )
3912
- { k , Map . put ( node_meta , :node , :node ) }
3913
- end )
3914
-
3915
- model_meta = % { output: id , nodes: nodes , axon: :axon }
3916
- params = Nx . serialize ( params , opts )
3917
- :erlang . term_to_binary ( { @ file_version , model_meta , params } , opts )
3918
- end
3919
-
3920
- # TODO: Raise on next release
3921
- defp validate_serialized_op! ( op_name , op ) when is_function ( op ) do
3922
- fun_info = Function . info ( op )
3923
-
3924
- case fun_info [ :type ] do
3925
- :local ->
3926
- Logger . warning (
3927
- "Attempting to serialize anonymous function in #{ inspect ( op_name ) } layer," <>
3928
- " this will result in errors during deserialization between" <>
3929
- " different processes, and will be unsupported in a future" <>
3930
- " release. You should instead use a fully-qualified MFA function" <>
3931
- " such as &Axon.Layers.dense/3"
3932
- )
3933
-
3934
- { :type , :external } ->
3935
- :ok
3936
- end
3937
- end
3938
-
3939
- defp validate_serialized_op! ( _name , op ) when is_atom ( op ) , do: :ok
3940
-
3941
- @ doc """
3942
- Deserializes serialized model and parameters into a `{model, params}`
3943
- tuple.
3944
-
3945
- It is the opposite of `Axon.serialize/3`.
3946
-
3947
- ## Examples
3948
-
3949
- iex> model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, kernel_initializer: :zeros, activation: :relu)
3950
- iex> {init_fn, _} = Axon.build(model)
3951
- iex> params = init_fn.(Nx.template({1, 2}, :f32), %{})
3952
- iex> serialized = Axon.serialize(model, params)
3953
- iex> {saved_model, saved_params} = Axon.deserialize(serialized)
3954
- iex> {_, predict_fn} = Axon.build(saved_model)
3955
- iex> predict_fn.(saved_params, Nx.tensor([[1.0, 1.0]]))
3956
- #Nx.Tensor<
3957
- f32[1][1]
3958
- [
3959
- [0.0]
3960
- ]
3961
- >
3962
-
3963
- """
3964
- @ doc type: :model
3965
- def deserialize ( serialized , opts \\ [ ] ) do
3966
- Logger . warning (
3967
- "Attempting to deserialize a serialized Axon model. Deserialization" <>
3968
- " is discouraged and will be deprecated, then removed in future" <>
3969
- " releases. You should keep your model definitions as code and" <>
3970
- " serialize your parameters using `Nx.serialize/2`."
3971
- )
3972
-
3973
- { 1 , model_meta , serialized_params } = :erlang . binary_to_term ( serialized , opts )
3974
- % { nodes: nodes , output: id } = model_meta
3975
-
3976
- nodes =
3977
- Map . new ( nodes , fn { k , % { op_name: op_name , op: op } = v } ->
3978
- validate_deserialized_op! ( op_name , op )
3979
-
3980
- node_struct =
3981
- v
3982
- |> Map . delete ( :node )
3983
- |> then ( & struct ( Axon.Node , & 1 ) )
3984
-
3985
- { k , node_struct }
3986
- end )
3987
-
3988
- model = % Axon { output: id , nodes: nodes }
3989
- params = Nx . deserialize ( serialized_params , opts )
3990
- { model , params }
3991
- end
3992
-
3993
- # TODO: Raise on next release
3994
- defp validate_deserialized_op! ( op_name , op ) when is_function ( op ) do
3995
- fun_info = Function . info ( op )
3996
-
3997
- case fun_info [ :type ] do
3998
- :local ->
3999
- Logger . warning (
4000
- "Attempting to deserialize anonymous function in #{ inspect ( op_name ) } layer," <>
4001
- " this will result in errors during deserialization between" <>
4002
- " different processes, and will be unsupported in a future" <>
4003
- " release"
4004
- )
4005
-
4006
- :external ->
4007
- unless function_exported? ( fun_info [ :module ] , fun_info [ :name ] , fun_info [ :arity ] ) do
4008
- Logger . warning (
4009
- "Attempting to deserialize model which depends on function" <>
4010
- " #{ inspect ( op ) } in layer #{ inspect ( op_name ) } which does not exist in" <>
4011
- " the current environment, check your dependencies"
4012
- )
4013
- end
4014
- end
4015
- end
4016
-
4017
- defp validate_deserialized_op! ( op , _op_name ) when is_atom ( op ) , do: :ok
4018
-
4019
3842
## Helpers
4020
3843
4021
3844
@ valid_initializers [ :zeros , :ones , :uniform , :normal , :identity ] ++
0 commit comments