Skip to content

Commit 07c16ee

Browse files
authored
Separate @lazy from @.. macro, fix some bugs (#48)
* change broadcasting macro & remove bugs * fix OADAM * change to explicit copy instead * replace similar with map * add warnings to docstring * one missing dot in OADAM * two bugs and a comment * make lazy broadcasting its own macro * more serious mutation test * fixup * name * name'
1 parent c73dea7 commit 07c16ee

File tree

5 files changed

+117
-96
lines changed

5 files changed

+117
-96
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,5 @@ Optimisers.trainable
4848
Optimisers.apply!
4949
Optimisers.init
5050
Optimisers.@..
51+
Optimisers.@lazy
5152
```

src/interface.jl

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,34 +78,42 @@ end
7878

7979
"""
8080
@.. x = x + y
81-
@.. x + y / z
8281
83-
Magic broadcasting macro, for use in `apply!` rules:
84-
* Applied to assignment `x = ...` it is like `@.` unless `!iswriteable(x)`,
85-
in which case it ignores `x`, and applies `@.` on the right.
86-
* Applied to other expressions, it broadcasts like `@.` but does not materialise,
87-
returning a `Broadcasted` object for later use.
82+
Sometimes in-place broadcasting macro, for use in `apply!` rules.
83+
If `iswriteable(x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
8884
"""
8985
macro var".."(ex)
90-
if Meta.isexpr(ex, :(=))
91-
dst = esc(ex.args[1])
92-
src = esc(Broadcast.__dot__(ex.args[2]))
93-
:(if $iswriteable($dst)
94-
$dst .= $src
95-
else
96-
$src
97-
end)
98-
else
99-
bc = esc(Broadcast.__dot__(ex))
100-
:($lazy.($bc))
101-
end
86+
Meta.isexpr(ex, :(=)) || throw("the macro @.. only accepts assignment, like @.. x = y + z")
87+
dst = esc(ex.args[1])
88+
src = esc(Broadcast.__dot__(ex.args[2]))
89+
:($dst = if $iswriteable($dst)
90+
$dst .= $src
91+
else
92+
$src
93+
end)
94+
end
95+
96+
"""
97+
x = @lazy y + z
98+
99+
Lazy broadcasting macro, for use in `apply!` rules. It broadcasts like `@.`
100+
but does not materialise, returning a `Broadcasted` object for later use.
101+
Beware that mutation of arguments will affect the result,
102+
and that if it is used in two places, work will be done twice.
103+
"""
104+
macro lazy(ex)
105+
bc = esc(Broadcast.__dot__(ex))
106+
:($lazy.($bc))
102107
end
103108

104109
function lazy end
105110
Broadcast.broadcasted(::typeof(lazy), x) = Lazy(x)
106111
struct Lazy{T}; bc::T; end
107112
Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
108113

114+
onevalue::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
115+
onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x)
116+
109117
function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type!
110118
ioc = IOContext(io, :compact => true)
111119
print(ioc, "Leaf(", ℓ.rule, ", ")

src/rules.jl

Lines changed: 61 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ init(o::Descent, x::AbstractArray) = nothing
1818
function apply!(o::Descent, state, x, dx)
1919
η = convert(float(eltype(x)), o.eta)
2020

21-
return state, @.. dx * η
21+
return state, @lazy dx * η # @lazy creates a Broadcasted, will later fuse with x .= x .- dx
2222
end
2323

2424
"""
@@ -41,10 +41,10 @@ Momentum(η = 1f-2, ρ = 9f-1) = Momentum{typeof(η)}(η, ρ)
4141
init(o::Momentum, x::AbstractArray) = zero(x)
4242

4343
function apply!(o::Momentum, state, x, dx)
44-
η, ρ, v = o.eta, o.rho, state
45-
v′ = @.. v = ρ * v - η * dx
44+
η, ρ, mvel = o.eta, o.rho, state
45+
@.. mvel = ρ * mvel + η * dx # Macro @.. broadcasts into mvel if it can, else @. of rhs.
4646

47-
return v′, @.. -v′
47+
return mvel, mvel
4848
end
4949

5050
"""
@@ -67,11 +67,12 @@ Nesterov(η = 1f-3, ρ = 9f-1) = Nesterov{typeof(η)}(η, ρ)
6767
init(o::Nesterov, x::AbstractArray) = zero(x)
6868

6969
function apply!(o::Nesterov, state, x, dx)
70-
η, ρ, v = o.eta, o.rho, state
71-
d = @.. ρ^2 * v - (1+ρ) * η * dx
72-
v′ = @.. v = ρ * v - η * dx
70+
η, ρ, vel = o.eta, o.rho, state
71+
72+
newdx = @. - ρ^2 * vel + (1+ρ) * η * dx # Cannot be lazy as this needs the old velocity
73+
@.. vel = ρ * vel - η * dx
7374

74-
return v′, @.. -d
75+
return vel, newdx
7576
end
7677

7778
"""
@@ -101,10 +102,11 @@ init(o::RMSProp, x::AbstractArray) = zero(x)
101102

102103
function apply!(o::RMSProp, state, x, dx)
103104
η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state
104-
acc′ = @.. acc = ρ * acc + (1 - ρ) * dx^2
105-
dx′ = @.. dx */ (sqrt(acc) + ϵ))
105+
106+
@.. acc = ρ * acc + (1 - ρ) * dx^2
107+
dx′ = @lazy dx */ (sqrt(acc) + ϵ))
106108

107-
return acc, dx′
109+
return acc, dx′
108110
end
109111

110112
"""
@@ -129,15 +131,15 @@ ADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = ADAM{typeof(η)}(
129131

130132
init(o::ADAM, x::AbstractArray) = (zero(x), zero(x), o.beta)
131133

132-
function apply!(o::ADAM{T}, state, x, dx) where T
134+
function apply!(o::ADAM, state, x, dx)
133135
η, β, ϵ = o.eta, o.beta, o.epsilon
134136
mt, vt, βt = state
135137

136-
mt′ = @.. mt = β[1] * mt + (one(T) - β[1]) * dx
137-
vt′ = @.. vt = β[2] * vt + (one(T) - β[2]) * dx ^ 2
138-
dx′ = @.. mt / (one(T) - βt[1]) / (sqrt(vt / (one(T) - βt[2])) + ϵ) * η
138+
@.. mt = β[1] * mt + (1 - β[1]) * dx
139+
@.. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
140+
dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η
139141

140-
return (mt, vt, βt .* β), dx′
142+
return (mt, vt, βt .* β), dx′
141143
end
142144

143145
"""
@@ -168,17 +170,17 @@ function apply!(o::RADAM, state, x, dx)
168170

169171
mt, vt, βt, t = state
170172

171-
mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
172-
vt′ = @.. vt = β[2] * vt + (1 - β[2]) * dx^2
173+
@.. mt = β[1] * mt + (1 - β[1]) * dx
174+
@.. vt = β[2] * vt + (1 - β[2]) * dx^2
173175
ρ = ρ∞ - 2*t * βt[2] / (1 - βt[2])
174176
if ρ > 4
175177
r = sqrt((ρ - 4) *- 2) * ρ∞/((ρ∞ - 4) * (ρ∞ - 2) * ρ))
176-
dx′ = @.. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η * r
178+
dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η * r
177179
else
178-
dx′ = @.. mt / (1 - βt[1]) * η
180+
dx′ = @lazy mt / (1 - βt[1]) * η
179181
end
180182

181-
return (mt, vt, βt .* β, t + 1), dx′
183+
return (mt, vt, βt .* β, t + 1), dx′
182184
end
183185

184186
"""
@@ -205,14 +207,13 @@ init(o::AdaMax, x::AbstractArray) = (zero(x), zero(x), o.beta)
205207

206208
function apply!(o::AdaMax, state, x, dx)
207209
η, β, ϵ = o.eta, o.beta, o.epsilon
208-
209210
mt, ut, βt = state
210211

211-
mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
212-
ut′ = @.. ut = max(β[2] * ut, abs(dx))
213-
dx′ = @../(1 - βt[1])) * mt/(ut + ϵ)
212+
@.. mt = β[1] * mt + (1 - β[1]) * dx
213+
@.. ut = max(β[2] * ut, abs(dx))
214+
dx′ = @lazy/(1 - βt[1])) * mt/(ut + ϵ)
214215

215-
return (mt, ut, βt .* β), dx′
216+
return (mt, ut, βt .* β), dx′
216217
end
217218

218219
"""
@@ -240,16 +241,15 @@ init(o::OADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, zero(x))
240241

241242
function apply!(o::OADAM, state, x, dx)
242243
η, β, ϵ = o.eta, o.beta, o.epsilon
244+
mt, vt, βt, term = state
243245

244-
mt, vt, βt, dx_ = state
246+
@.. mt = β[1] * mt + (1 - β[1]) * dx
247+
@.. vt = β[2] * vt + (1 - β[2]) * dx^2
248+
prev = copy(term)
249+
@.. term = η * mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
250+
dx′ = @lazy 2 * term - prev
245251

246-
mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
247-
vt′ = @.. vt = β[2] * vt + (1 - β[2]) * dx^2
248-
dx = @.. -dx_
249-
dx_′ = @.. dx_ = η * mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
250-
dx′ = @.. dx + 2*dx_
251-
252-
return (mt′, vt′, βt .* β, dx_′), dx′
252+
return (mt, vt, βt .* β, term), dx′
253253
end
254254

255255
"""
@@ -271,16 +271,16 @@ struct ADAGrad{T}
271271
end
272272
ADAGrad= 1f-1, ϵ = eps(typeof(η))) = ADAGrad{typeof(η)}(η, ϵ)
273273

274-
init(o::ADAGrad, x::AbstractArray) = fill!(similar(x), o.epsilon)
274+
init(o::ADAGrad, x::AbstractArray) = onevalue(o.epsilon, x)
275275

276276
function apply!(o::ADAGrad, state, x, dx)
277277
η, ϵ = o.eta, o.epsilon
278278
acc = state
279279

280-
acc′ = @.. acc = acc + dx^2
281-
dx′ = @.. dx * η / (sqrt(acc) + ϵ)
280+
@.. acc = acc + dx^2
281+
dx′ = @lazy dx * η / (sqrt(acc) + ϵ)
282282

283-
return acc, dx′
283+
return acc, dx′
284284
end
285285

286286
"""
@@ -307,13 +307,12 @@ function apply!(o::ADADelta, state, x, dx)
307307
ρ, ϵ = o.rho, o.epsilon
308308
acc, Δacc = state
309309

310-
acc′ = @.. acc = ρ * acc + (1 - ρ) * dx^2
311-
# DON'T remove epsilon from numerator
312-
# or even out of the square roots
313-
dx′ = @.. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
314-
Δacc′ = @.. Δacc = ρ * Δacc + (1 - ρ) * dx^2
310+
@.. acc = ρ * acc + (1 - ρ) * dx^2
311+
# DON'T remove epsilon from numerator or even out of the square roots!
312+
dx′ = @. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ) # Cannot be lazy as this needs the old Δacc
313+
@.. Δacc = ρ * Δacc + (1 - ρ) * dx′^2
315314

316-
return (acc, Δacc), dx′
315+
return (acc, Δacc), dx′
317316
end
318317

319318
"""
@@ -338,19 +337,18 @@ end
338337
AMSGrad= 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AMSGrad{typeof(η)}(η, β, ϵ)
339338

340339
init(o::AMSGrad, x::AbstractArray) =
341-
(fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon))
340+
(onevalue(o.epsilon, x), onevalue(o.epsilon, x), onevalue(o.epsilon, x))
342341

343342
function apply!(o::AMSGrad, state, x, dx)
344343
η, β, ϵ = o.eta, o.beta, o.epsilon
345-
346344
mt, vt, v̂t = state
347345

348-
mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
349-
vt′ = @.. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
350-
v̂t′ = @.. v̂t = max(v̂t, vt)
351-
dx′ = @.. η * mt / (sqrt(v̂t) + ϵ)
346+
@.. mt = β[1] * mt + (1 - β[1]) * dx
347+
@.. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
348+
@.. v̂t = max(v̂t, vt)
349+
dx′ = @lazy η * mt / (sqrt(v̂t) + ϵ)
352350

353-
return (mt, vt, v̂t), dx′
351+
return (mt, vt, v̂t), dx′
354352
end
355353

356354
"""
@@ -381,12 +379,12 @@ function apply!(o::NADAM, state, x, dx)
381379

382380
mt, vt, βt = state
383381

384-
mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
385-
vt′ = @.. vt = β[2] * vt + (1 - β[2]) * dx^2
386-
dx′ = @.. (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
382+
@.. mt = β[1] * mt + (1 - β[1]) * dx
383+
@.. vt = β[2] * vt + (1 - β[2]) * dx^2
384+
dx′ = @lazy (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
387385
(sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η
388386

389-
return (mt, vt, βt .* β), dx′
387+
return (mt, vt, βt .* β), dx′
390388
end
391389

392390
"""
@@ -405,7 +403,7 @@ weight decay regularization.
405403
(no need to change default)
406404
"""
407405
ADAMW= 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η))) =
408-
OptimiserChain(ADAM{typeof(η)}(η, β, ϵ), WeightDecay(γ))
406+
OptimiserChain(ADAM{typeof(η)}(η, β, ϵ), WeightDecay{typeof(η)}(γ))
409407

410408
"""
411409
AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
@@ -434,11 +432,11 @@ function apply!(o::AdaBelief, state, x, dx)
434432
η, β, ϵ = o.eta, o.beta, o.epsilon
435433
mt, st = state
436434

437-
mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
438-
st′ = @.. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
439-
dx′ = @.. η * mt / (sqrt(st) + ϵ)
435+
@.. mt = β[1] * mt + (1 - β[1]) * dx
436+
@.. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
437+
dx′ = @lazy η * mt / (sqrt(st) + ϵ)
440438

441-
return (mt, st), dx′
439+
return (mt, st), dx′
442440
end
443441

444442
"""
@@ -457,7 +455,7 @@ WeightDecay() = WeightDecay(5f-4)
457455
init(o::WeightDecay, x::AbstractArray) = nothing
458456

459457
function apply!(o::WeightDecay, state, x, dx)
460-
dx′ = @.. dx + o.wd * x
458+
dx′ = @lazy dx + o.wd * x
461459

462460
return state, dx′
463461
end
@@ -478,7 +476,7 @@ init(o::ClipGrad, x::AbstractArray) = nothing
478476

479477
function apply!(o::ClipGrad, state, x, dx)
480478
δ = convert(float(eltype(x)), o.delta)
481-
dx′ = @.. clamp(dx, -δ, δ)
479+
dx′ = @lazy clamp(dx, -δ, δ)
482480

483481
return state, dx′
484482
end
@@ -510,7 +508,7 @@ function apply!(o::ClipNorm, state, x, dx)
510508
end
511509
λ = min(o.omega / nrm, 1)
512510

513-
return state, @.. dx * λ
511+
return state, @lazy dx * λ
514512
end
515513

516514
"""

0 commit comments

Comments
 (0)