Commit b4fe66b

Add documentation for loadmodel!
1 parent dee5842 commit b4fe66b

2 files changed: 65 additions, 24 deletions

docs/src/saving.md (23 additions, 20 deletions)
@@ -2,7 +2,7 @@
 
 You may wish to save models so that they can be loaded and run in a later
 session. The easiest way to do this is via
-[BSON.jl](https://github.com/MikeInnes/BSON.jl).
+[BSON.jl](https://github.com/JuliaIO/BSON.jl).
 
 Save a model:
@@ -46,15 +46,17 @@ versions of Flux).
 
 !!! note
 
-    If a saved model's weights are stored on the GPU, the model will not load
+    If a saved model's parameters are stored on the GPU, the model will not load
     later on if there is no GPU support available. It's best to [move your model
     to the CPU](gpu.md) with `cpu(model)` before saving it.
 
-## Saving Model Weights
+!!! warning
 
-In some cases it may be useful to save only the model parameters themselves, and
-rebuild the model architecture in your code. You can use `params(model)` to get
-model parameters.
+    Previous versions of Flux suggested saving only the model weights using
+    `@save "mymodel.bson" params(model)`.
+    This is no longer recommended and is strongly discouraged.
+    Saving models this way will only store the trainable parameters, which
+    will result in incorrect behavior for layers like `BatchNorm`.
 
 ```Julia
 julia> using Flux
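To see why the warning above matters, note that `BatchNorm`'s running statistics are updated during training but are not trainable parameters, so they never appear in `params(model)`. A minimal sketch, assuming Flux's standard `BatchNorm` API:

```julia
using Flux

# BatchNorm tracks running statistics (a mean and variance) that are
# updated during training but are not trainable parameters.
bn = BatchNorm(4)

# Only the affine shift β and scale γ are collected here; the running
# statistics are left behind, which is why saving params(model) alone
# breaks BatchNorm at inference time.
length(Flux.params(bn))
```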
@@ -64,28 +66,29 @@ Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
 
 julia> weights = Flux.params(model);
 
-julia> using BSON: @save
-
-julia> @save "mymodel.bson" weights
-```
-
-You can easily load parameters back into a model with `Flux.loadparams!`.
+Loading the model as shown above will return a new model with the stored parameters.
+But sometimes you already have a model, and you want to load stored parameters into it.
+This can be done as follows:
 
 ```julia
-julia> using Flux
+using Flux: loadmodel!
+using BSON: @load
 
-julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
-Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
+# some predefined model
+model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
 
-julia> using BSON: @load
+# load one model into another
+model = loadmodel!(model, @load("mymodel.bson"))
+```
 
-julia> @load "mymodel.bson" weights
+This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory.
 
-julia> Flux.loadparams!(model, weights)
+```@docs
+Flux.loadmodel!
+Flux.isloadleaf
+Flux.loadleaf!
 ```
 
-The new `model` we created will now be identical to the one we saved parameters for.
-
 ## Checkpointing
 
 In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md).
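The in-memory copying described in the new documentation can be sketched with two `Chain`s of identical architecture (an illustration assuming the two-model `Flux.loadmodel!` method from this commit):

```julia
using Flux

# Two models with the same architecture but independent random weights.
src = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
dst = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)

# Copy every parameter (trainable and non-trainable) from src into dst.
Flux.loadmodel!(dst, src)

dst[1].weight == src[1].weight   # the first layers now hold identical values
```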

src/loading.jl (42 additions, 4 deletions)
@@ -1,8 +1,25 @@
+"""
+    isloadleaf(x)
+
+Return `true` whenever `x` should be treated as a "leaf node"
+for the purposes of loading parameters.
+By default, `isloadleaf` returns `true` if [`Functors.isleaf`](@ref)
+is `true` for all [`Functors.children(x)`](@ref `Functors.children`).
+
+You can override this function for a specific type if needed.
+"""
 isloadleaf(x) = all(Functors.isleaf, Functors.children(x))
 
-loadnumeric!(x, x̄, err) = x
-loadnumeric!(x::Zeros, x̄, err) = x
-function loadnumeric!(x::AbstractArray, x̄::AbstractArray, err)
+"""
+    loadleaf!(x, x̄, err)
+
+Copy `x̄` to `x`, or throw `err` when their sizes are mismatched.
+By default, use `copyto!` when `x` and `x̄` are arrays.
+Otherwise, just return `x`.
+"""
+loadleaf!(x, x̄, err) = x
+loadleaf!(x::Zeros, x̄, err) = x
+function loadleaf!(x::AbstractArray, x̄::AbstractArray, err)
   (size(x) == size(x̄)) || throw(err)
   copyto!(x, x̄)
 end
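The leaf-copying contract above can be exercised in plain Julia without Flux; a sketch using a hypothetical `loadleaf_sketch!` that mirrors the array method from the diff:

```julia
# Hypothetical stand-in mirroring loadleaf!: copy when shapes agree,
# throw the caller-supplied error otherwise, and no-op for non-arrays.
loadleaf_sketch!(x, x̄, err) = x
function loadleaf_sketch!(x::AbstractArray, x̄::AbstractArray, err)
    size(x) == size(x̄) || throw(err)
    copyto!(x, x̄)
end

x = zeros(2, 3)
loadleaf_sketch!(x, ones(2, 3), DimensionMismatch("size mismatch"))
x == ones(2, 3)   # the destination now holds the loaded values

# A shape mismatch surfaces as the supplied error:
# loadleaf_sketch!(zeros(2, 2), ones(2, 3), DimensionMismatch("size mismatch"))
```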
@@ -14,7 +31,7 @@ function _loadto!(m, m̄)
     throw(ArgumentError("Tried to load $m̄ into $m but the structures do not match."))
 
   err = DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")
-  foreach((l, l̄) -> loadnumeric!(l, l̄, err), ls, l̄s)
+  foreach((l, l̄) -> loadleaf!(l, l̄, err), ls, l̄s)
 
   return m
 end
@@ -23,6 +40,27 @@ function loadto!(m::T, m̄::S) where {T, S}
   _loadto!(m, m̄)
 end
 
+"""
+    loadmodel!(m, m̄)
+
+Copy all the parameters (trainable and non-trainable) from `m̄` to `m`.
+
+`loadmodel!` recursively walks `m` and `m̄` until it encounters
+a subfield `x` (i.e. a layer) for which `isloadleaf(x)` is true.
+The parameters of the matching subfield, `x̄`, are copied to `x`,
+throwing an error whenever:
+- `x` and `x̄` are not the same type (e.g. loading a `Conv` into a `Dense`)
+- `x` and `x̄` do not share the same fields
+- the parameter sizes are mismatched between `x` and `x̄`
+
+See [`loadleaf!`](@ref) for more details on the copy behavior.
+See [`isloadleaf`](@ref) for more details on which layers are considered leaves.
+
+!!! warning
+    This function allows `m̄` to be a vector or `Params` for backwards compatibility.
+    You should avoid using `loadmodel!` this way, because it skips most of the structural
+    checking used when `m̄` is also a struct. Silent errors may occur.
+"""
 function loadmodel!(m, xs::Params)
   for (p, x) in zip(params(m), xs)
     size(p) == size(x) ||
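The default leaf rule used by `isloadleaf` can be checked with plain Functors.jl; a sketch with a hypothetical two-field `Affine` struct (assuming Functors' `@functor`, `isleaf`, and `children`):

```julia
using Functors

# A hypothetical two-field layer; @functor exposes W and b as children.
struct Affine
    W
    b
end
@functor Affine

# Same default rule as isloadleaf in the diff above: a node is a load
# leaf when every child is itself a Functors leaf (here, plain arrays).
isloadleaf_sketch(x) = all(Functors.isleaf, Functors.children(x))

isloadleaf_sketch(Affine(rand(2, 2), rand(2)))   # true: both children are arrays
```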
