
Commit 33c8144

Replace ADAM with Adam etc (#78)
* replace ADAM with Adam etc
* also update names in tests
1 parent accfb00 commit 33c8144

6 files changed, with 69 additions and 59 deletions


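In user code, the rename amounts to swapping the all-caps constructors for the new CamelCase ones. A minimal sketch of the before/after, assuming the `@deprecate` lines added in `src/rules.jl` below behave like standard Julia deprecations (the old names still work, but warn):

```julia
using Optimisers

# New spellings exported from this commit onwards:
opt   = Adam(1f-3)                                # was ADAM(1f-3)
chain = OptimiserChain(WeightDecay(), AdaGrad())  # was ADAGrad()

# Old spellings are kept as deprecated aliases, so existing code still runs,
# printing a warning that points to the new name:
old = ADAM()   # forwards to Adam(), warning when depwarns are enabled
```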
Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes <[email protected]>"]
-version = "0.2.4"
+version = "0.2.5"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

README.md

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ It is initialised by `setup`, and then at each step, `update` returns both the n
 state, and the model with its trainable parameters adjusted:
 
 ```julia
-state = Optimisers.setup(Optimisers.ADAM(), model) # just once
+state = Optimisers.setup(Optimisers.Adam(), model) # just once
 
 state, model = Optimisers.update(state, model, grad) # at every step
 ```

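Expanded into a self-contained sketch, the updated README snippet might look like this. The model and the hand-written gradient here are placeholders of my own; in real use the gradient would come from an AD package, and any nested structure of arrays can serve as the model.

```julia
using Optimisers

# Toy "model": a nested structure of arrays (placeholder for illustration).
model = (weight = rand(Float32, 3, 3), bias = zeros(Float32, 3))

state = Optimisers.setup(Optimisers.Adam(), model)    # just once

# Gradient with the same structure as the model, written by hand here.
grad = (weight = ones(Float32, 3, 3), bias = ones(Float32, 3))

state, model = Optimisers.update(state, model, grad)  # at every step
```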
src/Optimisers.jl

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@ include("destructure.jl")
 export destructure
 
 include("rules.jl")
-export Descent, ADAM, Momentum, Nesterov, RMSProp,
-  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
+export Descent, Adam, Momentum, Nesterov, RMSProp,
+  AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
   WeightDecay, ClipGrad, ClipNorm, OptimiserChain
 
 """

src/rules.jl

Lines changed: 54 additions & 44 deletions
@@ -1,3 +1,11 @@
+@deprecate ADAM Adam
+@deprecate NADAM NAdam
+@deprecate ADAMW AdamW
+@deprecate RADAM RAdam
+@deprecate OADAM OAdam
+@deprecate ADAGrad AdaGrad
+@deprecate ADADelta AdaDelta
+
 """
     Descent(η = 1f-1)
 
@@ -110,9 +118,9 @@ function apply!(o::RMSProp, state, x, dx)
 end
 
 """
-    ADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
+    Adam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
 
-[ADAM](https://arxiv.org/abs/1412.6980) optimiser.
+[Adam](https://arxiv.org/abs/1412.6980) optimiser.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -122,16 +130,18 @@ end
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct ADAM{T}
+struct Adam{T}
   eta::T
   beta::Tuple{T, T}
   epsilon::T
 end
-ADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = ADAM{typeof(η)}(η, β, ϵ)
+Adam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = Adam{typeof(η)}(η, β, ϵ)
+
+const Adam = Adam
 
-init(o::ADAM, x::AbstractArray) = (zero(x), zero(x), o.beta)
+init(o::Adam, x::AbstractArray) = (zero(x), zero(x), o.beta)
 
-function apply!(o::ADAM, state, x, dx)
+function apply!(o::Adam, state, x, dx)
   η, β, ϵ = o.eta, o.beta, o.epsilon
   mt, vt, βt = state
 
@@ -143,9 +153,9 @@ function apply!(o::ADAM, state, x, dx)
 end
 
 """
-    RADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
+    RAdam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
 
-[Rectified ADAM](https://arxiv.org/abs/1908.03265) optimizer.
+[Rectified Adam](https://arxiv.org/abs/1908.03265) optimizer.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -155,16 +165,16 @@ end
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct RADAM{T}
+struct RAdam{T}
   eta::T
   beta::Tuple{T, T}
   epsilon::T
 end
-RADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = RADAM{typeof(η)}(η, β, ϵ)
+RAdam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = RAdam{typeof(η)}(η, β, ϵ)
 
-init(o::RADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, 1)
+init(o::RAdam, x::AbstractArray) = (zero(x), zero(x), o.beta, 1)
 
-function apply!(o::RADAM, state, x, dx)
+function apply!(o::RAdam, state, x, dx)
   η, β, ϵ = o.eta, o.beta, o.epsilon
   ρ∞ = 2/(1-β[2])-1
 
@@ -186,7 +196,7 @@ end
 """
     AdaMax(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
 
-[AdaMax](https://arxiv.org/abs/1412.6980) is a variant of ADAM based on the ∞-norm.
+[AdaMax](https://arxiv.org/abs/1412.6980) is a variant of Adam based on the ∞-norm.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -217,10 +227,10 @@ function apply!(o::AdaMax, state, x, dx)
 end
 
 """
-    OADAM(η = 1f-3, β = (5f-1, 9f-1), ϵ = eps(typeof(η)))
+    OAdam(η = 1f-3, β = (5f-1, 9f-1), ϵ = eps(typeof(η)))
 
-[OADAM](https://arxiv.org/abs/1711.00141) (Optimistic ADAM)
-is a variant of ADAM adding an "optimistic" term suitable for adversarial training.
+[OAdam](https://arxiv.org/abs/1711.00141) (Optimistic Adam)
+is a variant of Adam adding an "optimistic" term suitable for adversarial training.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -230,16 +240,16 @@ is a variant of ADAM adding an "optimistic" term suitable for adversarial traini
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct OADAM{T}
+struct OAdam{T}
  eta::T
  beta::Tuple{T, T}
  epsilon::T
 end
-OADAM(η = 1f-3, β = (5f-1, 9f-1), ϵ = eps(typeof(η))) = OADAM{typeof(η)}(η, β, ϵ)
+OAdam(η = 1f-3, β = (5f-1, 9f-1), ϵ = eps(typeof(η))) = OAdam{typeof(η)}(η, β, ϵ)
 
-init(o::OADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, zero(x))
+init(o::OAdam, x::AbstractArray) = (zero(x), zero(x), o.beta, zero(x))
 
-function apply!(o::OADAM, state, x, dx)
+function apply!(o::OAdam, state, x, dx)
   η, β, ϵ = o.eta, o.beta, o.epsilon
   mt, vt, βt, term = state
 
@@ -253,9 +263,9 @@ function apply!(o::OADAM, state, x, dx)
 end
 
 """
-    ADAGrad(η = 1f-1, ϵ = eps(typeof(η)))
+    AdaGrad(η = 1f-1, ϵ = eps(typeof(η)))
 
-[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
+[AdaGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
 parameter specific learning rates based on how frequently it is updated.
 Parameters don't need tuning.
 
@@ -265,15 +275,15 @@ Parameters don't need tuning.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct ADAGrad{T}
+struct AdaGrad{T}
   eta::T
   epsilon::T
 end
-ADAGrad(η = 1f-1, ϵ = eps(typeof(η))) = ADAGrad{typeof(η)}(η, ϵ)
+AdaGrad(η = 1f-1, ϵ = eps(typeof(η))) = AdaGrad{typeof(η)}(η, ϵ)
 
-init(o::ADAGrad, x::AbstractArray) = onevalue(o.epsilon, x)
+init(o::AdaGrad, x::AbstractArray) = onevalue(o.epsilon, x)
 
-function apply!(o::ADAGrad, state, x, dx)
+function apply!(o::AdaGrad, state, x, dx)
   η, ϵ = o.eta, o.epsilon
   acc = state
 
@@ -284,9 +294,9 @@ function apply!(o::ADAGrad, state, x, dx)
 end
 
 """
-    ADADelta(ρ = 9f-1, ϵ = eps(typeof(ρ)))
+    AdaDelta(ρ = 9f-1, ϵ = eps(typeof(ρ)))
 
-[ADADelta](https://arxiv.org/abs/1212.5701) is a version of ADAGrad adapting its learning
+[AdaDelta](https://arxiv.org/abs/1212.5701) is a version of AdaGrad adapting its learning
 rate based on a window of past gradient updates.
 Parameters don't need tuning.
 
@@ -295,15 +305,15 @@ Parameters don't need tuning.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct ADADelta{T}
+struct AdaDelta{T}
   rho::T
   epsilon::T
 end
-ADADelta(ρ = 9f-1, ϵ = eps(typeof(ρ))) = ADADelta{typeof(ρ)}(ρ, ϵ)
+AdaDelta(ρ = 9f-1, ϵ = eps(typeof(ρ))) = AdaDelta{typeof(ρ)}(ρ, ϵ)
 
-init(o::ADADelta, x::AbstractArray) = (zero(x), zero(x))
+init(o::AdaDelta, x::AbstractArray) = (zero(x), zero(x))
 
-function apply!(o::ADADelta, state, x, dx)
+function apply!(o::AdaDelta, state, x, dx)
   ρ, ϵ = o.rho, o.epsilon
   acc, Δacc = state
 
@@ -318,7 +328,7 @@ end
 """
     AMSGrad(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
 
-The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the ADAM
+The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the Adam
 optimiser. Parameters don't need tuning.
 
 # Parameters
@@ -352,9 +362,9 @@ function apply!(o::AMSGrad, state, x, dx)
 end
 
 """
-    NADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
+    NAdam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
 
-[NADAM](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) is a Nesterov variant of ADAM.
+[NAdam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) is a Nesterov variant of Adam.
 Parameters don't need tuning.
 
 # Parameters
@@ -365,16 +375,16 @@ Parameters don't need tuning.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-struct NADAM{T}
+struct NAdam{T}
   eta::T
   beta::Tuple{T, T}
   epsilon::T
 end
-NADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = NADAM{typeof(η)}(η, β, ϵ)
+NAdam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = NAdam{typeof(η)}(η, β, ϵ)
 
-init(o::NADAM, x::AbstractArray) = (zero(x), zero(x), o.beta)
+init(o::NAdam, x::AbstractArray) = (zero(x), zero(x), o.beta)
 
-function apply!(o::NADAM, state, x, dx)
+function apply!(o::NAdam, state, x, dx)
   η, β, ϵ = o.eta, o.beta, o.epsilon
 
   mt, vt, βt = state
@@ -388,9 +398,9 @@ function apply!(o::NADAM, state, x, dx)
 end
 
 """
-    ADAMW(η = 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η)))
+    AdamW(η = 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η)))
 
-[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its
+[AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
 
 # Parameters
@@ -402,14 +412,14 @@ weight decay regularization.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-ADAMW(η = 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η))) =
-  OptimiserChain(ADAM{typeof(η)}(η, β, ϵ), WeightDecay{typeof(η)}(γ))
+AdamW(η = 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η))) =
+  OptimiserChain(Adam{typeof(η)}(η, β, ϵ), WeightDecay{typeof(η)}(γ))
 
 """
     AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = 1e-16)
 
 The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known
-ADAM optimiser.
+Adam optimiser.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating

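One detail worth noting from this file: `AdamW` is not a struct of its own but a convenience constructor, so after the rename it still expands to a chain of `Adam` followed by `WeightDecay`. A small sketch of the equivalence, assuming `WeightDecay` accepts its decay constant positionally like the other rules:

```julia
using Optimisers

# These two describe the same update rule:
o1 = AdamW(1f-3, (9f-1, 9.99f-1), 1f-4)             # η, β, γ (weight decay)
o2 = OptimiserChain(Adam(1f-3), WeightDecay(1f-4))

o1 isa OptimiserChain   # true: AdamW has no struct of its own
```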
test/rules.jl

Lines changed: 9 additions & 9 deletions
@@ -6,14 +6,14 @@ Random.seed!(1)
 
 RULES = [
   # All the rules at default settings:
-  Descent(), ADAM(), Momentum(), Nesterov(), RMSProp(),
-  ADAGrad(), AdaMax(), ADADelta(), AMSGrad(), NADAM(),
-  ADAMW(), RADAM(), OADAM(), AdaBelief(),
+  Descent(), Adam(), Momentum(), Nesterov(), RMSProp(),
+  AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
+  AdamW(), RAdam(), OAdam(), AdaBelief(),
   # A few chained combinations:
-  OptimiserChain(WeightDecay(), ADAM(0.001)),
-  OptimiserChain(ClipNorm(), ADAM(0.001)),
+  OptimiserChain(WeightDecay(), Adam(0.001)),
+  OptimiserChain(ClipNorm(), Adam(0.001)),
   OptimiserChain(ClipGrad(0.5), Momentum()),
-  OptimiserChain(WeightDecay(), OADAM(), ClipGrad(1)),
+  OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),
 ]
 
 name(o) = typeof(o).name.name # just for printing testset headings
@@ -177,10 +177,10 @@ end
 @testset "with complex numbers: Flux#1776" begin
   empty!(LOG)
   @testset "$(name(opt))" for opt in [
-    # The Flux PR had 1e-2 for all. But ADADelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
-    ADAM(1e-2), RMSProp(1e-2), RADAM(1e-2), OADAM(1e-2), ADAGrad(1e-2), ADADelta(0.9, 1e-5), NADAM(1e-2), AdaBelief(1e-2),
+    # The Flux PR had 1e-2 for all. But AdaDelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
+    Adam(1e-2), RMSProp(1e-2), RAdam(1e-2), OAdam(1e-2), AdaGrad(1e-2), AdaDelta(0.9, 1e-5), NAdam(1e-2), AdaBelief(1e-2),
     # These weren't in Flux PR:
-    Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), ADAMW(1e-2),
+    Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), AdamW(1e-2),
   ]
     # Our "model" is just a complex number
     model = (w = zeros(ComplexF64, 1),)

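A side effect visible in the test output: the headings produced by the `name` helper above now show the new spellings. A small sketch (note that `AdamW()` reports as an `OptimiserChain`, since it is built from one):

```julia
using Optimisers

name(o) = typeof(o).name.name   # same helper as in test/rules.jl

name(Adam())     # :Adam      (was :ADAM before this commit)
name(AdaGrad())  # :AdaGrad   (was :ADAGrad)
name(AdamW())    # :OptimiserChain, because AdamW expands to a chain
```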
test/runtests.jl

Lines changed: 2 additions & 2 deletions
@@ -160,9 +160,9 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   ok = (1.0:3.0, sin, "abc", :abc)
   m = (α = ok, β = rand(3), γ = ok)
   m1 = (rand(3), m, rand(3))
-  @test Optimisers.setup(ADAMW(), m1) isa Tuple
+  @test Optimisers.setup(AdamW(), m1) isa Tuple
   m2 = (rand(3), m, rand(3), m, rand(3)) # illegal
-  @test_throws ArgumentError Optimisers.setup(ADAMW(), m2)
+  @test_throws ArgumentError Optimisers.setup(AdamW(), m2)
 end
 
 end
