Skip to content

Commit 444a6b9

Browse files
mcabbott and ToucheSir
authored
Add adjust! (#113)
* add mutating adjust * add tests * add to docs * Apply suggestions from code review Co-authored-by: Brian Chen <[email protected]> Co-authored-by: Brian Chen <[email protected]>
1 parent 36bebc5 commit 444a6b9

File tree

4 files changed

+94
-9
lines changed

4 files changed

+94
-9
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Optimisers.OptimiserChain
3434
Optimisers.setup
3535
Optimisers.update
3636
Optimisers.update!
37+
Optimisers.adjust!
3738
Optimisers.adjust(::Any, ::Real)
3839
Optimisers.freeze!
3940
Optimisers.thaw!

docs/src/index.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,26 @@ Optimisers.thaw!(opt)
165165
opt.layers[3].bias # Leaf(Momentum(...), [0.0, 0.0])
166166
```
167167

168+
## Adjusting Hyperparameters
169+
170+
To change the learning rate during training, use [`adjust!`](@ref Optimisers.adjust!).
171+
This works much like `freeze!` by mutating the state tree, or part of it,
172+
without discarding the momenta. For the Flux model from just above:
173+
174+
```julia
175+
Optimisers.adjust!(opt, 0.03) # change η for the whole model...
176+
177+
Optimisers.adjust!(opt.layers[3], 0.04) # ... or just for one layer.
178+
```
179+
180+
To change other fields of the optimisation rule, it accepts keyword arguments:
181+
182+
```julia
183+
Momentum |> fieldnames # (:eta, :rho)
184+
185+
Optimisers.adjust!(opt, rho = 0.95) # change ρ for the whole model.
186+
```
187+
168188
## Tied Parameters
169189

170190
If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
@@ -187,7 +207,7 @@ This identification relies on `===`, and will work for ordinary `Array`s and `Cu
187207
It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
188208
from StaticArrays.jl.
189209

190-
210+
191211
## Obtaining a flat parameter vector
192212

193213
Instead of a nested tree-like structure, sometimes it is convenient to have all the

src/adjust.jl

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,15 @@ thaw!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
5656
###
5757

5858
"""
59-
Optimisers.adjust(tree, η) -> tree
59+
Optimisers.adjust!(tree, η)
6060
6161
Alters the state `tree = setup(rule, model)` to change the parameters of the
6262
optimisation rule, without destroying its stored state. Typically used mid-way
6363
through training.
6464
65+
Can be applied to part of a model, by acting only on the corresponding part
66+
of the state `tree`.
67+
6568
To change just the learning rate, provide a number `η::Real`.
6669
6770
# Example
@@ -76,11 +79,13 @@ julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # wit
7679
julia> st
7780
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())
7881
79-
julia> st = Optimisers.adjust(st, 0.123) # change learning rate, stored momentum untouched
82+
julia> Optimisers.adjust!(st, 0.123) # change learning rate, stored momentum untouched
83+
84+
julia> st
8085
(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
8186
```
8287
83-
To change other parameters, `adjust` also accepts keyword arguments matching the field
88+
To change other parameters, `adjust!` also accepts keyword arguments matching the field
8489
names of the optimisation rule's type.
8590
8691
```
@@ -97,15 +102,30 @@ julia> Optimisers.adjust(st; beta = "no such field") # silently ignored!
97102
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing)
98103
```
99104
"""
100-
adjust(tree, eta::Real) = map(st -> adjust(st, eta), tree)
101-
adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
105+
adjust!(tree, eta::Real) = foreach(st -> adjust!(st, eta), tree)
106+
adjust!(tree; kw...) = foreach(st -> adjust!(st; kw...), tree)
102107

103-
adjust(::Nothing, ::Real) = nothing
104-
adjust(::Nothing; kw...) = nothing
108+
adjust!(ℓ::Leaf, eta::Real) = (ℓ.rule = adjust(ℓ.rule, eta); nothing)
109+
adjust!(ℓ::Leaf; kw...) = (ℓ.rule = adjust(ℓ.rule; kw...); nothing)
105110

106111
adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen)
107112
adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen)
108113

114+
"""
115+
adjust(tree, η) -> tree
116+
117+
Like [`adjust!`](@ref Optimisers.adjust!), but returns a new tree instead of mutating the old one.
118+
"""
119+
function adjust(tree, eta::Real)
120+
t′ = fmap(copy, tree; exclude = maywrite) # same as used for update / update!
121+
adjust!(t′, eta)
122+
t′
123+
end
124+
function adjust(tree; kw...)
125+
t′ = fmap(copy, tree; exclude = maywrite)
126+
adjust!(t′; kw...)
127+
t′
128+
end
109129

110130
"""
111131
Optimisers.adjust(rule::RuleType, η::Real) -> rule

test/runtests.jl

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ end
177177
@test eltype(s6[2].state[2]) == Float32
178178
end
179179

180-
@testset "adjusyting parameters" begin
180+
@testset "adjusting parameters, out-of-place" begin
181181
# Simple momentum:
182182
m = (α = ([0.0], sin), γ = Float32[4,3,2])
183183
s = Optimisers.setup(Momentum(0.1, 0.9), m)
@@ -221,6 +221,50 @@ end
221221
@test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
222222
end
223223

224+
@testset "adjusting parameters, in-place" begin
225+
# Simple momentum:
226+
m = (α = ([0.0], sin), γ = Float32[4,3,2])
227+
s = Optimisers.setup(Momentum(0.1, 0.9), m)
228+
s1, m1 = Optimisers.update(s, m, (α = nothing, γ = [1,10,100],))
229+
@test m.γ .- m1.γ ≈ [0.1, 1, 10]
230+
@test s1.γ.rule.eta == 0.1
231+
@test s1.γ.state ≈ [0.1, 1, 10]
232+
233+
Optimisers.adjust!(s1, 0.2)
234+
@test s1.γ.rule.eta == 0.2
235+
@test s1.γ.rule.rho == 0.9
236+
@test s1.γ.state ≈ [0.1, 1, 10]
237+
@test s1.α[1].rule.eta == 0.2
238+
239+
Optimisers.adjust!(s1; eta=0.3, rho=0.7)
240+
@test s1.γ.rule.eta == 0.3
241+
@test s1.γ.rule.rho == 0.7
242+
@test s1.γ.state ≈ [0.1, 1, 10]
243+
@test s1.α[1].rule.rho == 0.7
244+
245+
_, m3 = Optimisers.update(s1, m, (α = nothing, γ = [1,10,100],))
246+
@test !(m.γ .- m3.γ ≈ [1, 10, 100])
247+
248+
Optimisers.adjust!(s1, zeta = "this does nothing")
249+
@test s1.γ.rule.eta == 0.3
250+
251+
# OptimiserChain
252+
sc = Optimisers.setup(OptimiserChain(ClipGrad(2), Adam()), m)
253+
sc1, mc1 = Optimisers.update(sc, m, (α = nothing, γ = [1,10,100],))
254+
@test sc1.γ.rule.opts[2].eta == 0.001f0
255+
@test sc1.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
256+
257+
Optimisers.adjust!(sc1, 0.2)
258+
@test sc1.γ.rule.opts[1].delta == 2 # unchanged
259+
@test sc1.γ.rule.opts[2].eta === 0.2f0
260+
@test sc1.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
261+
262+
Optimisers.adjust!(sc1; delta = 2.5) # ClipGrad(2) does not store an Int, for this reason
263+
@test sc1.γ.rule.opts[1].delta == 2.5
264+
@test sc1.γ.rule.opts[2].eta === 0.2f0 # unchanged
265+
@test sc1.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
266+
end
267+
224268
@testset "freeze/thaw" begin
225269
m = (x=[1.0, 2.0], y=([3.0, 4.0], sin));
226270
st = Optimisers.setup(Descent(0.1), m);

0 commit comments

Comments
 (0)