
Commit 17102dc

Put an abstract type over the optimisers to allow dispatch (#79)
* Put an abstract type over the optimisers to allow dispatch. Required for SciML/Optimization.jl#255
* add depwarn, update docs, shorter name, version
* change to AbstractRule
* restrict types to Real too, while touching
* one more AbstractRule
* Revert "restrict types to Real too, while touching"; this reverts commit 014cc44.

Co-authored-by: Michael Abbott <[email protected]>
1 parent 5d3d741 commit 17102dc
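
The practical effect of the new supertype is that downstream packages can dispatch on "any Optimisers.jl rule" without listing every concrete type. A rough illustration (the `handle` function below is hypothetical, not part of this commit or of Optimisers.jl):

```julia
using Optimisers

# Hypothetical downstream method: accept any Optimisers.jl rule via the new
# supertype, and fall back to a generic method for everything else.
handle(rule::Optimisers.AbstractRule) = "Optimisers.jl rule: $(typeof(rule))"
handle(other) = "not an Optimisers.jl rule: $(typeof(other))"

handle(Adam())       # hits the AbstractRule method
handle((lr = 0.1,))  # hits the fallback
```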

File tree: 5 files changed (+23 lines, -19 lines)

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes <[email protected]>"]
-version = "0.2.5"
+version = "0.2.6"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ These act on one array of parameters:
 
 ```julia
 # Define a container to hold any optimiser specific parameters (if any):
-struct DecayDescent{T}
+struct DecayDescent{T} <: Optimisers.AbstractRule
   η::T
 end
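
For that documentation example to be usable, the custom rule also needs `Optimisers.init` and `Optimisers.apply!` methods. A minimal sketch consistent with the 0.2 interface (the particular 1/√t decay shown here is illustrative, not dictated by this commit):

```julia
# Initial per-parameter state for the rule (here just a step counter):
Optimisers.init(o::DecayDescent, x::AbstractArray) = 1

# Return the updated state and the change that will be subtracted from the parameters:
function Optimisers.apply!(o::DecayDescent, state, x, dx)
  newdx = (o.η / sqrt(state)) .* dx
  return state + 1, newdx
end
```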

src/Optimisers.jl

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ using Functors: functor, fmap, isleaf
 using LinearAlgebra
 
 include("interface.jl")
+export AbstractRule
 
 include("destructure.jl")
 export destructure
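
Since `AbstractRule` is exported, code that does `using Optimisers` can refer to it without qualification. For example (a hypothetical user-defined rule, shown only to illustrate the export):

```julia
using Optimisers

# The unqualified name resolves thanks to `export AbstractRule`:
struct MyDescent{T} <: AbstractRule
  eta::T
end
```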

src/interface.jl

Lines changed: 3 additions & 0 deletions
@@ -4,12 +4,15 @@ base(dx::Tangent) = backing(canonicalize(dx))
 base(dx) = dx
 const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}
 
+abstract type AbstractRule end
+
 struct Leaf{R,S}
   rule::R
   state::S
 end
 
 function setup(rule, x; seen = Base.IdSet())
+  rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup)
   if isnumeric(x)
     x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry."))
     isbits(x) || push!(seen, x)
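
The `Base.depwarn` call keeps old-style rules (plain structs that do not subtype `AbstractRule`) working, but nudges their authors to add the supertype. A sketch of what that looks like from the user side (the rule `LegacyDescent` is hypothetical; note that `Base.depwarn` output is hidden by default and appears when Julia is started with `--depwarn=yes`):

```julia
using Optimisers

# Pre-0.2.6 style rule: no supertype declared.
struct LegacyDescent{T}
  eta::T
end
Optimisers.init(::LegacyDescent, x::AbstractArray) = nothing
Optimisers.apply!(o::LegacyDescent, state, x, dx) = state, o.eta .* dx

# Still works, but `setup` now emits a deprecation warning for this rule:
state = Optimisers.setup(LegacyDescent(0.1), (w = [1.0, 2.0],))
```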

src/rules.jl

Lines changed: 17 additions & 17 deletions
@@ -16,7 +16,7 @@ For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.
 - Learning rate (`η`): Amount by which gradients are discounted before updating
   the weights.
 """
-struct Descent{T}
+struct Descent{T} <: AbstractRule
   eta::T
 end
 Descent() = Descent(1f-1)
@@ -40,7 +40,7 @@ Gradient descent optimizer with learning rate `η` and momentum `ρ`.
 - Momentum (`ρ`): Controls the acceleration of gradient descent in the
   prominent direction, in effect dampening oscillations.
 """
-struct Momentum{T}
+struct Momentum{T} <: AbstractRule
   eta::T
   rho::T
 end
@@ -66,7 +66,7 @@ Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`.
 - Nesterov momentum (`ρ`): Controls the acceleration of gradient descent in the
   prominent direction, in effect dampening oscillations.
 """
-struct Nesterov{T}
+struct Nesterov{T} <: AbstractRule
   eta::T
   rho::T
 end
@@ -104,7 +104,7 @@ gradients by an estimate their variance, instead of their second moment.
 - Keyword `centred` (or `centered`): Indicates whether to use centred variant
   of the algorithm.
 """
-struct RMSProp{T}
+struct RMSProp{T} <: AbstractRule
   eta::T
   rho::T
   epsilon::T
@@ -148,7 +148,7 @@ end
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct Adam{T}
+struct Adam{T} <: AbstractRule
   eta::T
   beta::Tuple{T, T}
   epsilon::T
@@ -183,7 +183,7 @@ end
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct RAdam{T}
+struct RAdam{T} <: AbstractRule
   eta::T
   beta::Tuple{T, T}
   epsilon::T
@@ -224,7 +224,7 @@ end
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct AdaMax{T}
+struct AdaMax{T} <: AbstractRule
   eta::T
   beta::Tuple{T, T}
   epsilon::T
@@ -258,7 +258,7 @@ is a variant of Adam adding an "optimistic" term suitable for adversarial traini
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct OAdam{T}
+struct OAdam{T} <: AbstractRule
   eta::T
   beta::Tuple{T, T}
   epsilon::T
@@ -293,7 +293,7 @@ Parameters don't need tuning.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct AdaGrad{T}
+struct AdaGrad{T} <: AbstractRule
   eta::T
   epsilon::T
 end
@@ -323,7 +323,7 @@ Parameters don't need tuning.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct AdaDelta{T}
+struct AdaDelta{T} <: AbstractRule
   rho::T
   epsilon::T
 end
@@ -357,7 +357,7 @@ optimiser. Parameters don't need tuning.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct AMSGrad{T}
+struct AMSGrad{T} <: AbstractRule
   eta::T
   beta::Tuple{T, T}
   epsilon::T
@@ -393,7 +393,7 @@ Parameters don't need tuning.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct NAdam{T}
+struct NAdam{T} <: AbstractRule
   eta::T
   beta::Tuple{T, T}
   epsilon::T
@@ -447,7 +447,7 @@ Adam optimiser.
 - Machine epsilon (`ϵ::Float32`): Constant to prevent division by zero
   (no need to change default)
 """
-struct AdaBelief{T}
+struct AdaBelief{T} <: AbstractRule
   eta::T
   beta::Tuple{T, T}
   epsilon::T
@@ -479,7 +479,7 @@ This is equivalent to adding ``L_2`` regularization with coefficient ``γ`` to t
 # Parameters
 - Weight decay (`γ`): Decay applied to weights during optimisation.
 """
-struct WeightDecay{T}
+struct WeightDecay{T} <: AbstractRule
   gamma::T
 end
 WeightDecay() = WeightDecay(5f-4)
@@ -499,7 +499,7 @@ Restricts every gradient component to obey `-δ ≤ dx[i] ≤ δ`.
 
 See also [`ClipNorm`](@ref).
 """
-struct ClipGrad{T<:Real}
+struct ClipGrad{T<:Real} <: AbstractRule
   delta::T
 end
 ClipGrad() = ClipGrad(10f0)
@@ -524,7 +524,7 @@ which you can turn off with `throw = false`.
 
 See also [`ClipGrad`](@ref).
 """
-struct ClipNorm{T<:Real}
+struct ClipNorm{T<:Real} <: AbstractRule
   omega::T
   p::T
   throw::Bool
@@ -566,7 +566,7 @@ julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
 ([-0.03, -0.1, -0.1],)
 ```
 """
-struct OptimiserChain{O<:Tuple}
+struct OptimiserChain{O<:Tuple} <: AbstractRule
   opts::O
 end
 OptimiserChain(opts...) = OptimiserChain(opts)
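
With every built-in rule and `OptimiserChain` now declared `<: AbstractRule`, generic code can check for or dispatch on the supertype. On Optimisers 0.2.6 a REPL check along these lines should hold:

```julia
julia> using Optimisers

julia> Descent(0.1) isa Optimisers.AbstractRule
true

julia> OptimiserChain(ClipGrad(1.0), Adam()) isa Optimisers.AbstractRule
true

julia> Adam <: Optimisers.AbstractRule   # the UnionAll type itself is also a subtype
true
```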

0 commit comments
