
Commit 132bdd8

mcabbott and darsnack authored
Wrap optimiser and state in a struct (#30)
* wrap optimiser and state in a struct
* rename state -> setup
* remove _update and use dispatch
* an idea
* fixup rebase
* rm opt(state, model, ...) methods
* move show, tweak things
* docs
* add a check that there are no tied weights
* a bug
* a test
* include strings
* add docstrings for all the basic functions
* wording

Co-authored-by: Kyle Daruwalla <[email protected]>
1 parent 7da717b commit 132bdd8

6 files changed: 208 additions and 111 deletions

docs/src/api.md

Lines changed: 21 additions & 0 deletions

@@ -1,3 +1,6 @@
+
+## Optimisation Rules
+
 ```@docs
 Optimisers.Descent
 Optimisers.Momentum
@@ -15,9 +18,27 @@ Optimisers.ADAMW
 Optimisers.AdaBelief
 ```
 
+In addition to the main course, you may wish to order some of these condiments:
+
 ```@docs
 Optimisers.ClipGrad
 Optimisers.ClipNorm
 Optimisers.WeightDecay
 Optimisers.OptimiserChain
 ```
+
+## Model Interface
+
+```@docs
+Optimisers.setup
+Optimisers.update
+Optimisers.update!
+```
+
+## Rule Definition
+
+```@docs
+Optimisers.apply!
+Optimisers.init
+Optimisers.@..
+```

docs/src/index.md

Lines changed: 24 additions & 16 deletions

@@ -2,48 +2,56 @@
 
 ## Define an Optimiser
 
+A new optimiser must overload two functions, `apply!` and `init`:
+
 ```julia
-# Define a container to hold any optimiser specific parameters (if any)
-struct Descent{T}
+# Define a container to hold any optimiser specific parameters (if any):
+struct DecayDescent{T}
   η::T
 end
 
-# Define an `apply!` rule with which to update the current params
-# using the gradients
-function Optimisers.apply!(o::Descent, state, m, m̄)
-  o.η .* m̄, state
+# Define an `apply!` rule which encodes how the gradients will be used to
+# update the parameters:
+function Optimisers.apply!(o::DecayDescent, state, x, x̄)
+  newx̄ = (o.η / state) .* x̄
+  nextstate = state + 1
+  return nextstate, newx̄
 end
 
-Optimisers.init(o, x::AbstractArray) = nothing
+# Define the function which sets up the initial state (if any):
+Optimisers.init(o::DecayDescent, x::AbstractArray) = 1
 ```
 
+The parameters will be immediately updated to `x .- newx̄`, while `nextstate` is
+carried to the next iteration.
+
 Notice that the state is handled separately from the optimiser itself. This
 is a key design principle and allows users to manage their own state explicitly.
 
 It of course also makes it easier to store the state.
 
 ## Usage
 
+To apply such an optimiser to a whole model, `setup` builds a tree containing any initial
+state for every trainable array. Then at each step, `update` uses this and the gradient
+to adjust the model:
+
 ```julia
 
 using Flux, Metalhead, Optimisers
 
 o = Optimisers.ADAM()  # define an ADAM optimiser with default settings
-st = Optimisers.state(o, m)  # initialize the optimiser before using it
+st = Optimisers.setup(o, m)  # initialize the optimiser before using it
 
-model = ResNet()  # define a model to train on
+model = ResNet18()  # define a model to train on
 ip = rand(Float32, 224, 224, 3, 1)  # dummy data
 
 m̄, _ = gradient(model, ip) do m, x  # calculate the gradients
-  sum(m(x))
+  sum(m(x))  # dummy loss function
 end
 
+st, mnew = Optimisers.update(st, m, m̄)
 
-st, mnew = Optimisers.update(o, st, m, m̄)
-
-# or
-
-st, mnew = o(m, m̄, st)
 ```
 
 Notice that a completely new instance of the model is returned. Internally, this
@@ -53,6 +61,6 @@ work with different forms of gradients, but most likely use case are the gradien
 returned by [Zygote.jl](https://fluxml.ai/Zygote.jl).
 
 There is also `Optimisers.update!` which similarly returns a new model and new state,
-but is free to mutate arrays within the old one, for efficiency.
+but is free to mutate arrays within the old one for efficiency.
 The method of `apply!` you write is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.
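To see the `DecayDescent` rule above in action without Flux or Metalhead, here is a minimal, self-contained sketch. The rule definition is copied from the docs; the toy `NamedTuple` model, the fake gradient, and the loop are illustrative additions of mine, assuming the `setup`/`update` API introduced by this commit:

```julia
using Optimisers

# The DecayDescent rule from the docs above, repeated so this sketch stands alone:
struct DecayDescent{T}
  η::T
end

function Optimisers.apply!(o::DecayDescent, state, x, x̄)
  newx̄ = (o.η / state) .* x̄   # the effective step size shrinks as the counter grows
  nextstate = state + 1
  return nextstate, newx̄
end

Optimisers.init(o::DecayDescent, x::AbstractArray) = 1

# A toy "model": just a NamedTuple of arrays, which setup/update walk via Functors.
model = (w = Float32[1, 2, 3],)
grad  = (w = Float32[1, 1, 1],)   # a made-up gradient, standing in for Zygote's output

st = Optimisers.setup(DecayDescent(0.1f0), model)
for _ in 1:3
  global st, model
  st, model = Optimisers.update(st, model, grad)
  @show model.w   # steps scale by 0.1, then 0.05, then 0.0333...
end
```

Each call returns a fresh state tree and model, matching the `update` contract documented below.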

src/Optimisers.jl

Lines changed: 93 additions & 0 deletions

@@ -10,4 +10,97 @@ export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
   WeightDecay, ClipGrad, ClipNorm, OptimiserChain
 
+"""
+    Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)
+
+This defines the action of any optimisation rule. It should return the modified gradient
+which will be subtracted from the parameters, and the updated state (if any) for use at
+the next iteration, as a tuple `(state, gradient)`.
+
+For efficiency it is free to mutate the old state, but only what is returned will be used.
+Ideally this should check `iswriteable(x)`, which the built-in rules do via [`@..`](@ref).
+
+The initial state is `init(rule::RuleType, parameters)`.
+
+# Example
+```jldoctest
+julia> Optimisers.init(Descent(0.1), [1,2,3]) === nothing
+true
+
+julia> Optimisers.apply!(Descent(0.1), nothing, [1,2,3], [4,5,6])
+(nothing, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(*, ([4, 5, 6], 0.1)))
+```
+"""
+apply!
+
+"""
+    Optimisers.init(rule::RuleType, parameters) -> state
+
+Sets up the initial state for a given optimisation rule, and an array of parameters.
+This and [`apply!`](@ref) are the two functions which any new optimisation rule must define.
+
+# Examples
+```jldoctest
+julia> Optimisers.init(Descent(), [1,2,3])  # is `nothing`
+
+julia> Optimisers.init(Momentum(), [1.0, 2.0])
+2-element Vector{Float64}:
+ 0.0
+ 0.0
+```
+"""
+init
+
+"""
+    Optimisers.setup(rule, model) -> tree
+
+Initialises the given optimiser for every trainable parameter within the model.
+Returns a tree of the relevant states, which must be passed to [`update`](@ref)
+or [`update!`](@ref).
+
+# Example
+```jldoctest
+julia> Optimisers.setup(Descent(0.1f0), (x = rand(3), y = (true, false), z = tanh))
+(x = Leaf(Descent{Float32}(0.1), nothing), y = (nothing, nothing), z = nothing)
+```
+"""
+setup
+
+"""
+    Optimisers.update(tree, model, gradient) -> (tree, model)
+
+Uses the optimiser and the gradient to change the trainable parameters in the model.
+Returns the improved model, and the optimiser states needed for the next update.
+The initial tree of states comes from [`setup`](@ref).
+
+See also [`update!`](@ref), which will be faster for models of ordinary `Array`s or `CuArray`s.
+
+# Example
+```jldoctest
+julia> m = (x = Float32[1,2,3], y = tanh);
+
+julia> t = Optimisers.setup(Descent(0.1f0), m)
+(x = Leaf(Descent{Float32}(0.1), nothing), y = nothing)
+
+julia> g = (x = [1,1,1], y = nothing);  # fake gradient
+
+julia> Optimisers.update(t, m, g)
+((x = Leaf(Descent{Float32}(0.1), nothing), y = nothing), (x = Float32[0.9, 1.9, 2.9], y = tanh))
+```
+"""
+update
+
+"""
+    Optimisers.update!(tree, model, gradient) -> (tree, model)
+
+Uses the optimiser and the gradient to change the trainable parameters in the model.
+Returns the improved model, and the optimiser states needed for the next update.
+The initial tree of states comes from [`setup`](@ref).
+
+This is used in exactly the same manner as [`update`](@ref), but because it may mutate
+arrays within the old model (and the old state), it will be faster for models of ordinary
+`Array`s or `CuArray`s. However, you should not rely on the old model being fully updated.
+"""
+update!
+
 end # module
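To make the `update` / `update!` contrast in these docstrings concrete, here is a small sketch of my own (not part of the commit), reusing the docstrings' NamedTuple-as-model style; the expected numbers come from the `update` docstring example:

```julia
using Optimisers

m  = (x = Float32[1, 2, 3],)   # stand-in for a model
g  = (x = Float32[1, 1, 1],)   # stand-in for its gradient
st = Optimisers.setup(Descent(0.1f0), m)

# `update` works on defensive copies, so the original model is untouched:
st2, m2 = Optimisers.update(st, m, g)
@show m2.x   # Float32[0.9, 1.9, 2.9], as in the docstring example
@show m.x    # still Float32[1.0, 2.0, 3.0]

# `update!` may mutate arrays inside the old model and state for speed;
# per its docstring, only the returned values should be relied upon:
st3, m3 = Optimisers.update!(st, m, g)
@show m3.x   # Float32[0.9, 1.9, 2.9]
```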

src/interface.jl

Lines changed: 33 additions & 16 deletions

@@ -1,39 +1,48 @@
 
-function state(o, x)
+struct Leaf{R,S}
+  rule::R
+  state::S
+end
+
+function setup(rule, x; seen = Base.IdSet())
   if isnumeric(x)
-    return init(o, x)
+    x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry."))
+    isbits(x) || push!(seen, x)
+    return Leaf(rule, init(rule, x))
   elseif isleaf(x)
     return nothing
   else
     x′, _ = functor(x)
-    return map(xᵢ -> state(o, xᵢ), x′)
+    return map(xᵢ -> setup(rule, xᵢ; seen), x′)
   end
 end
 
-patch!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : (x .- x̄)
+subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : (x .- x̄)
 
-function _update!(o, st, x, x̄s...)
-  st′, x̄′ = apply!(o, st, x, x̄s...)
-  return st′, patch!(x, x̄′)
+function update!(ℓ::Leaf, x, x̄s...)
+  if all(isnothing, x̄s)
+    return ℓ, x
+  else
+    s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, x̄s...)
+    return Leaf(ℓ.rule, s′), subtract!(x, x̄′)
+  end
 end
 
-function update!(o, state, x, x̄s...)
+function update!(tree, x, x̄s...)
   if all(isnothing, x̄s)
-    return state, x
-  elseif isnumeric(x)
-    return _update!(o, state, x, x̄s...)
+    return tree, x
   else
     x̄s′ = map(x̄ -> functor(typeof(x), x̄)[1], x̄s)
    x′, re = functor(typeof(x), x)
-    xstate = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(o, stᵢ, xᵢ, x̄sᵢ...), state, x′, x̄s′...)
-    return map(first, xstate), re(map(last, xstate))
+    xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
+    return map(first, xtree), re(map(last, xtree))
   end
 end
 
-function update(o, state, x, x̄s...)
-  state′ = fmap(copy, state; exclude = iswriteable)
+function update(tree, x, x̄s...)
+  t′ = fmap(copy, tree; exclude = iswriteable)
   x′ = fmap(copy, x; exclude = iswriteable)
-  update!(o, state′, x′, x̄s...)
+  update!(t′, x′, x̄s...)
 end
 
 # default all rules to first order calls
@@ -75,3 +84,11 @@ function lazy end
 Broadcast.broadcasted(::typeof(lazy), x) = Lazy(x)
 struct Lazy{T}; bc::T; end
 Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
+
+function Base.show(io::IO, ℓ::Leaf)  # show method is mostly to hide its long type!
+  ioc = IOContext(io, :compact => true)
+  print(ioc, "Leaf(", ℓ.rule, ", ")
+  show(ioc, ℓ.state)
+  print(io, ")")
+end
+
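The `seen = Base.IdSet()` bookkeeping in the new `setup` is what enforces the "no tied weights" check mentioned in the commit message: arrays are tracked by identity, and a second visit to the same array throws. A small sketch of that behaviour, using a toy model of my own and assuming this commit's `setup`:

```julia
using Optimisers

w = Float32[1, 2, 3]
tied = (a = w, b = w)          # the same array reachable from two places

try
  Optimisers.setup(Descent(0.1f0), tied)
catch err
  println(err)                 # ArgumentError: tied weights are not handled yet
end

untied = (a = w, b = copy(w))  # distinct arrays are fine
Optimisers.setup(Descent(0.1f0), untied)
```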
