Skip to content

Commit 755a97a

Browse files
authored
Merge pull request #29 from mcabbott/isnumeric
Optimise only at `isnumeric` leaves
2 parents 9acfb38 + 4160122 commit 755a97a

File tree

3 files changed

+62
-37
lines changed

3 files changed

+62
-37
lines changed

src/interface.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,37 @@
11
patch(x, x̄) = x .- x̄
22

33
function state(o, x)
4-
if isleaf(x)
4+
if isnumeric(x)
55
return init(o, x)
6+
elseif isleaf(x)
7+
return nothing
68
else
7-
x, _ = functor(x)
8-
return map(x -> state(o, x), x)
9+
x, _ = functor(x)
10+
return map(xᵢ -> state(o, xᵢ), x)
911
end
1012
end
1113

1214
function _update(o, st, x, x̄s...)
13-
st, x̄ = apply(o, st, x, x̄s...)
14-
return st, patch(x, x̄)
15+
st, x̄ = apply(o, st, x, x̄s...)
16+
return st, patch(x, x̄)
1517
end
1618

1719
function update(o, state, x::T, x̄s...) where T
1820
if all(isnothing, x̄s)
1921
return state, x
20-
elseif isleaf(x)
22+
elseif isnumeric(x)
2123
return _update(o, state, x, x̄s...)
2224
else
23-
x̄s = map(x̄ -> functor(typeof(x), x̄)[1], x̄s)
24-
x, restructure = functor(typeof(x), x)
25-
xstate = map((state, x, x̄s...) -> update(o, state, x, x̄s...), state, x, x̄s...)
26-
return map(first, xstate), restructure(map(last, xstate))
25+
x̄s = map(x̄ -> functor(typeof(x), x̄)[1], x̄s)
26+
x′, re = functor(typeof(x), x)
27+
xstate = map((stᵢ, xᵢ, x̄sᵢ...) -> update(o, stᵢ, xᵢ, x̄sᵢ...), state, x, x̄s...)
28+
return map(first, xstate), re(map(last, xstate))
2729
end
2830
end
2931

3032
# default all rules to first order calls
3133
apply(o, state, x, dx, dxs...) = apply(o, state, x, dx)
34+
35+
isnumeric(x::AbstractArray{<:Number}) = isleaf(x) # isleaf to allow for e.g. transposed shared weights
36+
isnumeric(x::AbstractArray{<:Bool}) = false # convention of ChainRules is that Bool is non-differentiable
37+
isnumeric(x) = false

src/rules.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ init(o::RMSProp, x::AbstractArray) = zero(x)
108108
function apply(o::RMSProp, state, x, dx)
109109
η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state
110110
@. acc = ρ * acc + (1 - ρ) * dx^2
111-
dx = @. dx * (η / (sqrt(acc) + ϵ))
111+
dx = @. dx * (η / (sqrt(acc) + ϵ))
112112

113-
return acc, dx
113+
return acc, dx
114114
end
115115

116116
(o::RMSProp)(state, m, dm) = update(o, state, m, dm)
@@ -145,9 +145,9 @@ function apply(o::ADAM{T}, state, x, dx) where T
145145

146146
@. mt = β[1] * mt + (one(T) - β[1]) * dx
147147
@. vt = β[2] * vt + (one(T) - β[2]) * dx ^ 2
148-
dx = @. mt / (one(T) - βt[1]) / (sqrt(vt / (one(T) - βt[2])) + ϵ) * η
148+
dx = @. mt / (one(T) - βt[1]) / (sqrt(vt / (one(T) - βt[2])) + ϵ) * η
149149

150-
return (mt, vt, βt .* β), dx
150+
return (mt, vt, βt .* β), dx
151151
end
152152

153153
"""
@@ -185,12 +185,12 @@ function apply(o::RADAM, state, x, dx)
185185
ρ = ρ∞ - 2*t * βt[2] / (1 - βt[2])
186186
if ρ > 4
187187
r = sqrt((ρ - 4) * (ρ - 2) * ρ∞/((ρ∞ - 4) * (ρ∞ - 2) * ρ))
188-
dx = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η * r
188+
dx = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η * r
189189
else
190-
dx = @. mt / (1 - βt[1]) * η
190+
dx = @. mt / (1 - βt[1]) * η
191191
end
192192

193-
return (mt, vt, βt .* β, t + 1), dx
193+
return (mt, vt, βt .* β, t + 1), dx
194194
end
195195

196196
"""
@@ -224,9 +224,9 @@ function apply(o::AdaMax, state, x, dx)
224224

225225
@. mt = β[1] * mt + (1 - β[1]) * dx
226226
@. ut = max(β[2] * ut, abs(dx))
227-
dx = @. (η/(1 - βt[1])) * mt/(ut + ϵ)
227+
dx = @. (η/(1 - βt[1])) * mt/(ut + ϵ)
228228

229-
return (mt, ut, βt .* β), dx
229+
return (mt, ut, βt .* β), dx
230230
end
231231

232232
"""
@@ -263,9 +263,9 @@ function apply(o::OADAM, state, x, dx)
263263
@. vt = β[2] * vt + (1 - β[2]) * dx^2
264264
@. dx = -dx_
265265
@. dx_ = η * mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
266-
dx = @. dx + 2*dx_
266+
dx = @. dx + 2*dx_
267267

268-
return (mt, vt, βt .* β, dx_), dx
268+
return (mt, vt, βt .* β, dx_), dx
269269
end
270270

271271
"""
@@ -296,9 +296,9 @@ function apply(o::ADAGrad, state, x, dx)
296296
acc = state
297297

298298
@. acc += dx^2
299-
dx = @. dx * η / (sqrt(acc) + ϵ)
299+
dx = @. dx * η / (sqrt(acc) + ϵ)
300300

301-
return acc, dx
301+
return acc, dx
302302
end
303303

304304
"""
@@ -330,10 +330,10 @@ function apply(o::ADADelta, state, x, dx)
330330
@. acc = ρ * acc + (1 - ρ) * dx^2
331331
# DON'T remove epsilon from numerator
332332
# or even out of the square roots
333-
dx = @. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
333+
dx = @. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
334334
@. Δacc = ρ * Δacc + (1 - ρ) * dx^2
335335

336-
return (acc, Δacc), dx
336+
return (acc, Δacc), dx
337337
end
338338

339339
"""
@@ -370,9 +370,9 @@ function apply(o::AMSGrad, state, x, dx)
370370
@. mt = β[1] * mt + (1 - β[1]) * dx
371371
@. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
372372
@. v̂t = max(v̂t, vt)
373-
dx = @. η * mt / (sqrt(v̂t) + ϵ)
373+
dx = @. η * mt / (sqrt(v̂t) + ϵ)
374374

375-
return (mt, vt, v̂t), dx
375+
return (mt, vt, v̂t), dx
376376
end
377377

378378
"""
@@ -407,10 +407,10 @@ function apply(o::NADAM, state, x, dx)
407407

408408
@. mt = β[1] * mt + (1 - β[1]) * dx
409409
@. vt = β[2] * vt + (1 - β[2]) * dx^2
410-
dx = @. (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
410+
dx = @. (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
411411
(sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η
412412

413-
return (mt, vt, βt .* β), dx
413+
return (mt, vt, βt .* β), dx
414414
end
415415

416416
"""
@@ -462,9 +462,9 @@ function apply(o::AdaBelief, state, x, dx)
462462

463463
@. mt = β[1] * mt + (1 - β[1]) * dx
464464
@. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
465-
dx = @. η * mt / (sqrt(st) + ϵ)
465+
dx = @. η * mt / (sqrt(st) + ϵ)
466466

467-
return (mt, st), dx
467+
return (mt, st), dx
468468
end
469469

470470
"""
@@ -485,9 +485,9 @@ init(o::WeightDecay, x::AbstractArray) = nothing
485485
(o::WeightDecay)(state, m, dm) = update(o, state, m, dm)
486486

487487
function apply(o::WeightDecay, state, x, dx)
488-
dx = @. dx + o.wd * x
488+
dx = @. dx + o.wd * x
489489

490-
return state, dx
490+
return state, dx
491491
end
492492

493493
"""

test/runtests.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,34 @@ using Statistics
88
@testset for o in (Descent(), ADAM(), Momentum(), Nesterov(), RMSProp(),
99
ADAGrad(), AdaMax(), ADADelta(), AMSGrad(), NADAM(),
1010
ADAMW(), RADAM(), OADAM(), AdaBelief())
11-
w = (α = rand(3, 3), β = rand(3, 3))
11+
12+
# Original example
13+
w = (α = 5rand(3, 3), β = rand(3, 3))
1214
st = Optimisers.state(o, w)
1315
loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
14-
l = loss(w, w′)
16+
@test loss(w, w′) > 1
1517
for i = 1:10^4
1618
gs = gradient(x -> loss(x, w′), w)
17-
st, w = o(st, w, gs...)
19+
st, w = Optimisers.update(o, st, w, gs...)
20+
end
21+
lw = loss(w, w′)
22+
@test lw < 0.001
23+
24+
# Slightly harder variant
25+
m = (α = randn(3), β = transpose(5rand(3,3)), γ = (rand(2), tanh)) # issue 28
26+
st = Optimisers.state(o, m)
27+
@test loss(m, w′) > 1
28+
for i = 1:10^4
29+
gs = gradient(x -> loss(x, w′), m)
30+
st, m = o(st, m, gs...)
31+
end
32+
lm = loss(m, w′)
33+
if lm < 0.1
34+
@test lm < 0.1
35+
else
36+
@test_broken lm < 0.1 # @test keyword broken doesn't exist on Julia 1.6
1837
end
19-
@test loss(w, w′) < 0.01
38+
2039
end
2140
end
2241

0 commit comments

Comments
 (0)