Skip to content

Commit 1ccbeba

Browse files
authored
Use model state struct instead of parameters (#553)
1 parent 19803a0 commit 1ccbeba

File tree

12 files changed

+1731
-1935
lines changed

12 files changed

+1731
-1935
lines changed

examples/vision/mnist.exs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
Mix.install([
2-
{:axon, "~> 0.5"},
2+
{:axon, path: "~/projects/axon"},
33
{:polaris, "~> 0.1"},
4-
{:exla, "~> 0.5"},
5-
{:nx, "~> 0.5"},
4+
{:exla, ">= 0.0.0"},
65
{:scidata, "~> 0.1"}
76
])
87

98
defmodule Mnist do
10-
require Axon
11-
129
@batch_size 32
1310
@image_side_pixels 28
1411
@channel_value_max 255

lib/axon.ex

Lines changed: 26 additions & 203 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,6 @@ defmodule Axon do
281281

282282
require Logger
283283

284-
# Axon serialization version
285-
@file_version 1
286-
287284
@type t :: %__MODULE__{}
288285

289286
defstruct [
@@ -417,16 +414,18 @@ defmodule Axon do
417414
}
418415
end
419416

420-
def param(name, shape, opts) when is_tuple(shape) or is_function(shape) do
421-
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32})
417+
def param(name, shape, opts) when is_binary(name) and (is_tuple(shape) or is_function(shape)) do
418+
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter)
422419
initializer = validate_initializer!(opts[:initializer])
423420
type = opts[:type] || {:f, 32}
421+
kind = opts[:kind] || :parameter
424422

425423
%Axon.Parameter{
426424
name: name,
427425
shape: shape,
428426
type: type,
429-
initializer: initializer
427+
initializer: initializer,
428+
kind: kind
430429
}
431430
end
432431

@@ -586,7 +585,7 @@ defmodule Axon do
586585
iex> inp1 = Axon.input("input_0", shape: {nil, 1})
587586
iex> inp2 = Axon.input("input_1", shape: {nil, 2})
588587
iex> model = Axon.container(%{a: inp1, b: inp2})
589-
iex> %{a: a, b: b} = Axon.predict(model, %{}, %{
588+
iex> %{a: a, b: b} = Axon.predict(model, Axon.ModelState.empty(), %{
590589
...> "input_0" => Nx.tensor([[1.0]]),
591590
...> "input_1" => Nx.tensor([[1.0, 2.0]])
592591
...> })
@@ -667,42 +666,6 @@ defmodule Axon do
667666
end
668667
end
669668

670-
@doc """
671-
Wraps an Axon model into a namespace.
672-
673-
A namespace is a part of an Axon model which is meant to
674-
be a self-contained collection of Axon layers. Namespaces
675-
are guaranteed to always generate with the same internal
676-
layer names and can be re-used universally across models.
677-
678-
Namespaces are most useful for containing large collections
679-
of layers and offering a straightforward means for accessing
680-
the parameters of individual model components. A common application
681-
of namespaces is to use them in with a pre-trained model for
682-
fine-tuning:
683-
684-
{base, resnet_params} = resnet()
685-
base = base |> Axon.namespace("resnet")
686-
687-
model = base |> Axon.dense(1)
688-
{init_fn, predict_fn} = Axon.build(model)
689-
690-
init_fn.(Nx.template({1, 3, 224, 224}, {:f, 32}), %{"resnset" => resnet_params})
691-
692-
Notice you can use `init_fn` in conjunction with namespaces
693-
to specify which portion of a model you'd like to initialize
694-
from a fixed starting point.
695-
696-
Namespaces have fixed names, which means it's easy to run into namespace
697-
collisions. Re-using namespaces, re-using inner parts of a namespace,
698-
and attempting to share layers between namespaces are still sharp
699-
edges in namespace usage.
700-
"""
701-
@doc type: :special
702-
def namespace(%Axon{} = axon, name) when is_binary(name) do
703-
layer(:namespace, [axon], name: name)
704-
end
705-
706669
@doc """
707670
Returns a function which represents a self-contained re-usable block
708671
of operations in a neural network. All parameters in the block are
@@ -1561,7 +1524,8 @@ defmodule Axon do
15611524
key_state =
15621525
param("key", fn _ -> {2} end,
15631526
type: {:u, 32},
1564-
initializer: fn _, _ -> Nx.Random.key(seed) end
1527+
initializer: fn _, _ -> Nx.Random.key(seed) end,
1528+
kind: :state
15651529
)
15661530

15671531
layer(dropout, [x, key_state],
@@ -1867,8 +1831,8 @@ defmodule Axon do
18671831
gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer])
18681832
beta = param("beta", beta_shape, initializer: opts[:beta_initializer])
18691833

1870-
mean = param("mean", mean_shape, initializer: :zeros)
1871-
var = param("var", var_shape, initializer: :ones)
1834+
mean = param("mean", mean_shape, initializer: :zeros, kind: :state)
1835+
var = param("var", var_shape, initializer: :ones, kind: :state)
18721836

18731837
layer(
18741838
norm,
@@ -3006,14 +2970,16 @@ defmodule Axon do
30062970
key_state =
30072971
param("key", fn _ -> {2} end,
30082972
type: {:u, 32},
3009-
initializer: fn _, _ -> Nx.Random.key(seed) end
2973+
initializer: fn _, _ -> Nx.Random.key(seed) end,
2974+
kind: :state
30102975
)
30112976

30122977
name =
30132978
case parent_name do
30142979
nil ->
30152980
fn _, op_counts ->
3016-
"lstm_#{op_counts[rnn_type]}_#{state_name}_hidden_state"
2981+
count = op_counts[rnn_type] || 0
2982+
"#{Atom.to_string(rnn_type)}_#{count}_#{state_name}_hidden_state"
30172983
end
30182984

30192985
parent_name when is_binary(parent_name) ->
@@ -3042,9 +3008,16 @@ defmodule Axon do
30423008

30433009
arity == 3 ->
30443010
fun =
3045-
fn inputs, key, _opts ->
3011+
fn inputs, key, opts ->
30463012
shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type)
3047-
initializer.(shape, {:f, 32}, key)
3013+
keys = Nx.Random.split(key)
3014+
out = initializer.(shape, {:f, 32}, keys[1])
3015+
3016+
if opts[:mode] == :train do
3017+
%Axon.StatefulOutput{output: out, state: %{"key" => keys[0]}}
3018+
else
3019+
out
3020+
end
30483021
end
30493022

30503023
{fun, [x, key_state]}
@@ -3168,6 +3141,7 @@ defmodule Axon do
31683141
the update process.
31693142
"""
31703143
@doc type: :model
3144+
@deprecated "Use Axon.ModelState.freeze/2 instead"
31713145
def freeze(model, fun_or_predicate \\ :all) do
31723146
freeze(model, fun_or_predicate, true)
31733147
end
@@ -3240,6 +3214,7 @@ defmodule Axon do
32403214
the update process.
32413215
"""
32423216
@doc type: :model
3217+
@deprecated "Use Axon.ModelState.freeze/2 instead"
32433218
def unfreeze(model, fun_or_predicate \\ :all) do
32443219
freeze(model, fun_or_predicate, false)
32453220
end
@@ -3410,7 +3385,7 @@ defmodule Axon do
34103385
out =
34113386
Nx.Defn.jit(
34123387
fn inputs ->
3413-
forward_fn.(init_fn.(inputs, %{}), inputs)
3388+
forward_fn.(init_fn.(inputs, Axon.ModelState.empty()), inputs)
34143389
end,
34153390
compiler: Axon.Defn
34163391
).(inputs)
@@ -3864,158 +3839,6 @@ defmodule Axon do
38643839
end
38653840
end
38663841

3867-
# Serialization
3868-
3869-
@doc """
3870-
Serializes a model and its parameters for persisting
3871-
models to disk or elsewhere.
3872-
3873-
Model and parameters are serialized as a tuple, where the
3874-
model is converted to a recursive map to ensure compatibility
3875-
with future Axon versions and the parameters are serialized
3876-
using `Nx.serialize/2`. There is some additional metadata included
3877-
such as current serialization version for compatibility.
3878-
3879-
Serialization `opts` are forwarded to `Nx.serialize/2` and
3880-
`:erlang.term_to_binary/2` for controlling compression options.
3881-
3882-
## Examples
3883-
3884-
iex> model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, kernel_initializer: :zeros, activation: :relu)
3885-
iex> {init_fn, _} = Axon.build(model)
3886-
iex> params = init_fn.(Nx.template({1, 2}, :f32), %{})
3887-
iex> serialized = Axon.serialize(model, params)
3888-
iex> {saved_model, saved_params} = Axon.deserialize(serialized)
3889-
iex> {_, predict_fn} = Axon.build(saved_model)
3890-
iex> predict_fn.(saved_params, Nx.tensor([[1.0, 1.0]]))
3891-
#Nx.Tensor<
3892-
f32[1][1]
3893-
[
3894-
[0.0]
3895-
]
3896-
>
3897-
3898-
"""
3899-
@doc type: :model
3900-
def serialize(%Axon{output: id, nodes: nodes}, params, opts \\ []) do
3901-
Logger.warning(
3902-
"Attempting to serialize an Axon model. Serialization is discouraged" <>
3903-
" and will be deprecated, then removed in future releases. You should" <>
3904-
" keep your model definitions as code and serialize your parameters using" <>
3905-
" `Nx.serialize/2`."
3906-
)
3907-
3908-
nodes =
3909-
Map.new(nodes, fn {k, %{op: op, op_name: op_name} = v} ->
3910-
validate_serialized_op!(op_name, op)
3911-
node_meta = Map.from_struct(v)
3912-
{k, Map.put(node_meta, :node, :node)}
3913-
end)
3914-
3915-
model_meta = %{output: id, nodes: nodes, axon: :axon}
3916-
params = Nx.serialize(params, opts)
3917-
:erlang.term_to_binary({@file_version, model_meta, params}, opts)
3918-
end
3919-
3920-
# TODO: Raise on next release
3921-
defp validate_serialized_op!(op_name, op) when is_function(op) do
3922-
fun_info = Function.info(op)
3923-
3924-
case fun_info[:type] do
3925-
:local ->
3926-
Logger.warning(
3927-
"Attempting to serialize anonymous function in #{inspect(op_name)} layer," <>
3928-
" this will result in errors during deserialization between" <>
3929-
" different processes, and will be unsupported in a future" <>
3930-
" release. You should instead use a fully-qualified MFA function" <>
3931-
" such as &Axon.Layers.dense/3"
3932-
)
3933-
3934-
{:type, :external} ->
3935-
:ok
3936-
end
3937-
end
3938-
3939-
defp validate_serialized_op!(_name, op) when is_atom(op), do: :ok
3940-
3941-
@doc """
3942-
Deserializes serialized model and parameters into a `{model, params}`
3943-
tuple.
3944-
3945-
It is the opposite of `Axon.serialize/3`.
3946-
3947-
## Examples
3948-
3949-
iex> model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, kernel_initializer: :zeros, activation: :relu)
3950-
iex> {init_fn, _} = Axon.build(model)
3951-
iex> params = init_fn.(Nx.template({1, 2}, :f32), %{})
3952-
iex> serialized = Axon.serialize(model, params)
3953-
iex> {saved_model, saved_params} = Axon.deserialize(serialized)
3954-
iex> {_, predict_fn} = Axon.build(saved_model)
3955-
iex> predict_fn.(saved_params, Nx.tensor([[1.0, 1.0]]))
3956-
#Nx.Tensor<
3957-
f32[1][1]
3958-
[
3959-
[0.0]
3960-
]
3961-
>
3962-
3963-
"""
3964-
@doc type: :model
3965-
def deserialize(serialized, opts \\ []) do
3966-
Logger.warning(
3967-
"Attempting to deserialize a serialized Axon model. Deserialization" <>
3968-
" is discouraged and will be deprecated, then removed in future" <>
3969-
" releases. You should keep your model definitions as code and" <>
3970-
" serialize your parameters using `Nx.serialize/2`."
3971-
)
3972-
3973-
{1, model_meta, serialized_params} = :erlang.binary_to_term(serialized, opts)
3974-
%{nodes: nodes, output: id} = model_meta
3975-
3976-
nodes =
3977-
Map.new(nodes, fn {k, %{op_name: op_name, op: op} = v} ->
3978-
validate_deserialized_op!(op_name, op)
3979-
3980-
node_struct =
3981-
v
3982-
|> Map.delete(:node)
3983-
|> then(&struct(Axon.Node, &1))
3984-
3985-
{k, node_struct}
3986-
end)
3987-
3988-
model = %Axon{output: id, nodes: nodes}
3989-
params = Nx.deserialize(serialized_params, opts)
3990-
{model, params}
3991-
end
3992-
3993-
# TODO: Raise on next release
3994-
defp validate_deserialized_op!(op_name, op) when is_function(op) do
3995-
fun_info = Function.info(op)
3996-
3997-
case fun_info[:type] do
3998-
:local ->
3999-
Logger.warning(
4000-
"Attempting to deserialize anonymous function in #{inspect(op_name)} layer," <>
4001-
" this will result in errors during deserialization between" <>
4002-
" different processes, and will be unsupported in a future" <>
4003-
" release"
4004-
)
4005-
4006-
:external ->
4007-
unless function_exported?(fun_info[:module], fun_info[:name], fun_info[:arity]) do
4008-
Logger.warning(
4009-
"Attempting to deserialize model which depends on function" <>
4010-
" #{inspect(op)} in layer #{inspect(op_name)} which does not exist in" <>
4011-
" the current environment, check your dependencies"
4012-
)
4013-
end
4014-
end
4015-
end
4016-
4017-
defp validate_deserialized_op!(op, _op_name) when is_atom(op), do: :ok
4018-
40193842
## Helpers
40203843

40213844
@valid_initializers [:zeros, :ones, :uniform, :normal, :identity] ++

0 commit comments

Comments
 (0)