Commit 5ac2d6f
Use eltype(x) everywhere, ignore typeof(η) (#151)
* always convert learning rate to eltype(momentum) before use
* don't parameterise rule structs, just use Float64
* fix tests like === 0.2f0
* fix a constructor
* fix AdamW
* fix Rprop
* use a macro to define structs with default values
* use T = eltype(x)
* a few more structs
* more structs
* fix tests
* doc fixes
* fix docstrings
* skip Yota on nightly
* docstrings
* breaking change, v0.3-dev
* print Adam(0.01f0) without 0.009999999776482582
* Revert "skip Yota on nightly" (reverts commit abf0e13)
* don't accidentally write to stdout
1 parent 322a6bb commit 5ac2d6f
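In short: rule structs now store plain `Float64` hyperparameters instead of carrying a type parameter, and `apply!` converts the learning rate to `eltype(x)` at the point of use, so a `Float32` model stays in `Float32`. A minimal sketch of the intended behaviour (hypothetical REPL-style snippet, not part of the diff):

```julia
using Optimisers

x = ones(Float32, 3)                        # Float32 parameters
state = Optimisers.setup(Descent(0.1), x)   # eta is now stored as Float64

state, x = Optimisers.update(state, x, ones(Float32, 3))

eltype(x)   # Float32: the Float64 eta is converted to eltype(x) inside apply!
```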

File tree: 7 files changed, +214 -178 lines


Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes <[email protected]>"]
-version = "0.2.20"
+version = "0.3.0-DEV"

 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 4 additions & 3 deletions

@@ -7,14 +7,15 @@ These act on one array of parameters:

 ```julia
 # Define a container to hold any optimiser specific parameters (if any):
-struct DecayDescent{T} <: Optimisers.AbstractRule
-  eta::T
+struct DecayDescent <: Optimisers.AbstractRule
+  eta::Float64
 end

 # Define an `apply!` rule which encodes how the gradients will be used to
 # update the parameters:
 function Optimisers.apply!(o::DecayDescent, state, x, x̄)
-  newx̄ = (o.eta / state) .* x̄
+  T = eltype(x)
+  newx̄ = T(o.eta / state) .* x̄
   nextstate = state + 1
   return nextstate, newx̄
 end
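A quick illustration of the pattern the updated docs recommend (hypothetical snippet, not part of the diff): `eta` is a `Float64` field, but `T(o.eta / state)` converts it to `eltype(x)`, so a `Float32` array gets a `Float32` update. An `init` method is assumed here so that `setup` works; starting the step count at 1 matches the `o.eta / state` schedule above.

```julia
using Optimisers

# Assumes the DecayDescent struct and apply! method from the example above are defined.
# Assumed state initialisation (step counter), needed for setup to work:
Optimisers.init(o::DecayDescent, x::AbstractArray) = 1

x = Float32[1, 2]
state = Optimisers.setup(DecayDescent(0.1), x)
state, x = Optimisers.update(state, x, Float32[1, 1])

eltype(x)   # Float32, even though eta::Float64
```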

src/Optimisers.jl

Lines changed: 7 additions & 7 deletions

@@ -78,7 +78,7 @@ or [`update!`](@ref).
 julia> m = (x = rand(3), y = (true, false), z = tanh);

 julia> Optimisers.setup(Momentum(), m)  # same field names as m
-(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
+(x = Leaf(Momentum(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
 ```

 The recursion into structures uses Functors.jl, and any new `struct`s containing parameters
@@ -91,15 +91,15 @@ julia> struct Layer; mat; fun; end
 julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);

 julia> Optimisers.setup(Momentum(), model)  # new struct is by default ignored
-(lay = (), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
+(lay = (), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))

 julia> destructure(model)
 (Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))

 julia> using Functors; @functor Layer  # annotate this type as containing parameters

 julia> Optimisers.setup(Momentum(), model)
-(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
+(lay = (mat = Leaf(Momentum(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))

 julia> destructure(model)
 (Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
@@ -120,13 +120,13 @@ See also [`update!`](@ref), which will be faster for models of ordinary `Array`s
 ```jldoctest
 julia> m = (x = Float32[1,2,3], y = tanh);

-julia> t = Optimisers.setup(Descent(0.1f0), m)
-(x = Leaf(Descent{Float32}(0.1), nothing), y = ())
+julia> t = Optimisers.setup(Descent(0.1), m)
+(x = Leaf(Descent(0.1), nothing), y = ())

 julia> g = (x = [1,1,1], y = nothing);  # fake gradient

 julia> Optimisers.update(t, m, g)
-((x = Leaf(Descent{Float32}(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
+((x = Leaf(Descent(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
 ```
 """
 update
@@ -152,7 +152,7 @@ julia> using StaticArrays, Zygote, Optimisers
 julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]);  # partly mutable model

 julia> t = Optimisers.setup(Momentum(1/30, 0.9), m)  # tree of states
-(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))
+(x = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]))

 julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]  # structural gradient
 (x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])
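These doctest changes reflect what is now printed: `Leaf` shows rules without a type parameter, and the compact `Base.show` added in src/interface.jl (below) hides constructor round-off. A rough sketch of the expected display (hedged; exact output may differ slightly):

```julia
using Optimisers

Adam(0.01f0)   # expected to display as Adam(0.01, (0.9, 0.999), 1.0e-8)
               # rather than Adam(0.009999999776482582, ...), via :compact printing
```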

src/adjust.jl

Lines changed: 9 additions & 9 deletions

@@ -23,15 +23,15 @@ julia> Optimisers.freeze!(s.x)
 julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi]));  # with fake gradient

 julia> m
-(x = ([1.0], 2.0), y = [-0.14159258336972558])
+(x = ([1.0], 2.0), y = [-0.14159265358979312])

 julia> s
-(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
+(x = (Leaf(Momentum(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum(0.01, 0.9), [3.14159]))

 julia> Optimisers.thaw!(s)

 julia> s.x
-(Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), ())
+(Leaf(Momentum(0.01, 0.9), [0.0]), ())
 ```
 """
 freeze!(tree) = foreach(freeze!, tree)
@@ -72,17 +72,17 @@ To change just the learning rate, provide a number `η::Real`.
 julia> m = (vec = rand(Float32, 2), fun = sin);

 julia> st = Optimisers.setup(Nesterov(), m)  # stored momentum is initialised to zero
-(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = ())
+(vec = Leaf(Nesterov(0.001, 0.9), Float32[0.0, 0.0]), fun = ())

 julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing));  # with fake gradient

 julia> st
-(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())
+(vec = Leaf(Nesterov(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())

 julia> Optimisers.adjust!(st, 0.123)  # change learning rate, stored momentum untouched

 julia> st
-(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
+(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
 ```

 To change other parameters, `adjust!` also accepts keyword arguments matching the field
@@ -93,13 +93,13 @@ julia> fieldnames(Adam)
 (:eta, :beta, :epsilon)

 julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
-(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
+(vec = Leaf(OptimiserChain(ClipGrad(10.0), Adam(0.001, (0.9, 0.999), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())

 julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1)  # delta acts on ClipGrad
-(vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
+(vec = Leaf(OptimiserChain(ClipGrad(11.1), Adam(0.001, (0.777, 0.909), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())

 julia> Optimisers.adjust(st; beta = "no such field")  # silently ignored!
-(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
+(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
 ```
 """
 adjust!(tree, eta::Real) = foreach(st -> adjust!(st, eta), tree)

src/interface.jl

Lines changed: 43 additions & 0 deletions

@@ -7,6 +7,10 @@ const Zero = Union{Nothing, AbstractZero}  # Union{Zygote, Diffractor}

 abstract type AbstractRule end

+function Base.show(io::IO, rule::AbstractRule)  # makes Adam(0.01f0) prettier
+  invoke(show, Tuple{IO,Any}, IOContext(io, :compact => true), rule)
+end
+
 ###
 ### setup
 ###
@@ -225,3 +229,42 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
 onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
 onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x)

+nonneg(η::Real) = η < 0 ? throw(DomainError(η, "the learning rate cannot be negative")) : η
+
+"""
+    @def struct Rule; eta = 0.1; beta = (0.7, 0.8); end
+
+Helper macro for defining rules with default values.
+The types of the literal values are used in the `struct`,
+like this:
+```
+struct Rule
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  Rule(eta = 0.1, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
+end
+```
+Any field called `eta` is assumed to be a learning rate, and cannot be negative.
+"""
+macro def(expr)
+  Meta.isexpr(expr, :struct) || throw("@def must act on a struct definition")
+  lines = expr.args[3].args
+  names, vals = [], []
+  for i in eachindex(lines)
+    lines[i] isa Symbol && throw("@def requires a default for every field")
+    Meta.isexpr(lines[i], :(=)) || continue
+    name, val = lines[i].args
+    push!(names, name)
+    push!(vals, val)
+    lines[i] = :($name::$typeof($val))
+  end
+  rule = Meta.isexpr(expr.args[2], :<:) ? expr.args[2].args[1] : expr.args[2]
+  check = :eta in names ? :(eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))) : nothing
+  inner = :(function $rule($([Expr(:kw, nv...) for nv in zip(names,vals)]...))
+    $check
+    new($(names...))
+  end)
+  push!(lines, inner)
+  esc(expr)
+end
