@@ -18,7 +18,7 @@ init(o::Descent, x::AbstractArray) = nothing
18
18
function apply! (o:: Descent , state, x, dx)
19
19
η = convert (float (eltype (x)), o. eta)
20
20
21
- return state, @. . dx * η
21
+ return state, @lazy dx * η # @lazy creates a Broadcasted, will later fuse with x .= x .- dx
22
22
end
23
23
24
24
"""
@@ -41,10 +41,10 @@ Momentum(η = 1f-2, ρ = 9f-1) = Momentum{typeof(η)}(η, ρ)
41
41
init (o:: Momentum , x:: AbstractArray ) = zero (x)
42
42
43
43
function apply! (o:: Momentum , state, x, dx)
44
- η, ρ, v = o. eta, o. rho, state
45
- v′ = @. . v = ρ * v - η * dx
44
+ η, ρ, mvel = o. eta, o. rho, state
45
+ @. . mvel = ρ * mvel + η * dx # Macro @.. broadcasts into mvel if it can, else @. of rhs.
46
46
47
- return v′, @. . - v′
47
+ return mvel, mvel
48
48
end
49
49
50
50
"""
@@ -67,11 +67,12 @@ Nesterov(η = 1f-3, ρ = 9f-1) = Nesterov{typeof(η)}(η, ρ)
67
67
init (o:: Nesterov , x:: AbstractArray ) = zero (x)
68
68
69
69
function apply! (o:: Nesterov , state, x, dx)
70
- η, ρ, v = o. eta, o. rho, state
71
- d = @. . ρ^ 2 * v - (1 + ρ) * η * dx
72
- v′ = @. . v = ρ * v - η * dx
70
+ η, ρ, vel = o. eta, o. rho, state
71
+
72
+ newdx = @. - ρ^ 2 * vel + (1 + ρ) * η * dx # Cannot be lazy as this needs the old velocity
73
+ @. . vel = ρ * vel - η * dx
73
74
74
- return v′, @. . - d
75
+ return vel, newdx
75
76
end
76
77
77
78
"""
@@ -101,10 +102,11 @@ init(o::RMSProp, x::AbstractArray) = zero(x)
101
102
102
103
function apply! (o:: RMSProp , state, x, dx)
103
104
η, ρ, ϵ, acc = o. eta, o. rho, o. epsilon, state
104
- acc′ = @. . acc = ρ * acc + (1 - ρ) * dx^ 2
105
- dx′ = @. . dx * (η / (sqrt (acc) + ϵ))
105
+
106
+ @. . acc = ρ * acc + (1 - ρ) * dx^ 2
107
+ dx′ = @lazy dx * (η / (sqrt (acc) + ϵ))
106
108
107
- return acc′ , dx′
109
+ return acc, dx′
108
110
end
109
111
110
112
"""
@@ -129,15 +131,15 @@ ADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = ADAM{typeof(η)}(
129
131
130
132
init (o:: ADAM , x:: AbstractArray ) = (zero (x), zero (x), o. beta)
131
133
132
- function apply! (o:: ADAM{T} , state, x, dx) where T
134
+ function apply! (o:: ADAM , state, x, dx)
133
135
η, β, ϵ = o. eta, o. beta, o. epsilon
134
136
mt, vt, βt = state
135
137
136
- mt′ = @. . mt = β[1 ] * mt + (one (T) - β[1 ]) * dx
137
- vt′ = @. . vt = β[2 ] * vt + (one (T) - β[2 ]) * dx ^ 2
138
- dx′ = @. . mt / (one (T) - βt[1 ]) / (sqrt (vt / (one (T) - βt[2 ])) + ϵ) * η
138
+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
139
+ @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx ^ 2
140
+ dx′ = @lazy mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ) * η
139
141
140
- return (mt′ , vt′ , βt .* β), dx′
142
+ return (mt, vt, βt .* β), dx′
141
143
end
142
144
143
145
"""
@@ -168,17 +170,17 @@ function apply!(o::RADAM, state, x, dx)
168
170
169
171
mt, vt, βt, t = state
170
172
171
- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
172
- vt′ = @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
173
+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
174
+ @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
173
175
ρ = ρ∞ - 2 * t * βt[2 ] / (1 - βt[2 ])
174
176
if ρ > 4
175
177
r = sqrt ((ρ - 4 ) * (ρ - 2 ) * ρ∞/ ((ρ∞ - 4 ) * (ρ∞ - 2 ) * ρ))
176
- dx′ = @. . mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ) * η * r
178
+ dx′ = @lazy mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ) * η * r
177
179
else
178
- dx′ = @. . mt / (1 - βt[1 ]) * η
180
+ dx′ = @lazy mt / (1 - βt[1 ]) * η
179
181
end
180
182
181
- return (mt′ , vt′ , βt .* β, t + 1 ), dx′
183
+ return (mt, vt, βt .* β, t + 1 ), dx′
182
184
end
183
185
184
186
"""
@@ -205,14 +207,13 @@ init(o::AdaMax, x::AbstractArray) = (zero(x), zero(x), o.beta)
205
207
206
208
function apply! (o:: AdaMax , state, x, dx)
207
209
η, β, ϵ = o. eta, o. beta, o. epsilon
208
-
209
210
mt, ut, βt = state
210
211
211
- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
212
- ut′ = @. . ut = max (β[2 ] * ut, abs (dx))
213
- dx′ = @. . (η/ (1 - βt[1 ])) * mt/ (ut + ϵ)
212
+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
213
+ @. . ut = max (β[2 ] * ut, abs (dx))
214
+ dx′ = @lazy (η/ (1 - βt[1 ])) * mt/ (ut + ϵ)
214
215
215
- return (mt′ , ut′ , βt .* β), dx′
216
+ return (mt, ut, βt .* β), dx′
216
217
end
217
218
218
219
"""
@@ -240,16 +241,15 @@ init(o::OADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, zero(x))
240
241
241
242
function apply! (o:: OADAM , state, x, dx)
242
243
η, β, ϵ = o. eta, o. beta, o. epsilon
244
+ mt, vt, βt, term = state
243
245
244
- mt, vt, βt, dx_ = state
246
+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
247
+ @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
248
+ prev = copy (term)
249
+ @. . term = η * mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ)
250
+ dx′ = @lazy 2 * term - prev
245
251
246
- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
247
- vt′ = @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
248
- dx = @. . - dx_
249
- dx_′ = @. . dx_ = η * mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ)
250
- dx′ = @. . dx + 2 * dx_
251
-
252
- return (mt′, vt′, βt .* β, dx_′), dx′
252
+ return (mt, vt, βt .* β, term), dx′
253
253
end
254
254
255
255
"""
@@ -271,16 +271,16 @@ struct ADAGrad{T}
271
271
end
272
272
ADAGrad (η = 1f-1 , ϵ = eps (typeof (η))) = ADAGrad {typeof(η)} (η, ϵ)
273
273
274
- init (o:: ADAGrad , x:: AbstractArray ) = fill! ( similar (x), o. epsilon)
274
+ init (o:: ADAGrad , x:: AbstractArray ) = onevalue ( o. epsilon, x )
275
275
276
276
function apply! (o:: ADAGrad , state, x, dx)
277
277
η, ϵ = o. eta, o. epsilon
278
278
acc = state
279
279
280
- acc′ = @. . acc = acc + dx^ 2
281
- dx′ = @. . dx * η / (sqrt (acc) + ϵ)
280
+ @. . acc = acc + dx^ 2
281
+ dx′ = @lazy dx * η / (sqrt (acc) + ϵ)
282
282
283
- return acc′ , dx′
283
+ return acc, dx′
284
284
end
285
285
286
286
"""
@@ -307,13 +307,12 @@ function apply!(o::ADADelta, state, x, dx)
307
307
ρ, ϵ = o. rho, o. epsilon
308
308
acc, Δacc = state
309
309
310
- acc′ = @. . acc = ρ * acc + (1 - ρ) * dx^ 2
311
- # DON'T remove epsilon from numerator
312
- # or even out of the square roots
313
- dx′ = @. . dx * sqrt (Δacc + ϵ) / sqrt (acc + ϵ)
314
- Δacc′ = @. . Δacc = ρ * Δacc + (1 - ρ) * dx^ 2
310
+ @. . acc = ρ * acc + (1 - ρ) * dx^ 2
311
+ # DON'T remove epsilon from numerator or even out of the square roots!
312
+ dx′ = @. dx * sqrt (Δacc + ϵ) / sqrt (acc + ϵ) # Cannot be lazy as this needs the old Δacc
313
+ @. . Δacc = ρ * Δacc + (1 - ρ) * dx′^ 2
315
314
316
- return (acc′ , Δacc′ ), dx′
315
+ return (acc, Δacc), dx′
317
316
end
318
317
319
318
"""
@@ -338,19 +337,18 @@ end
338
337
AMSGrad (η = 1f-3 , β = (9f-1 , 9.99f-1 ), ϵ = eps (typeof (η))) = AMSGrad {typeof(η)} (η, β, ϵ)
339
338
340
339
init (o:: AMSGrad , x:: AbstractArray ) =
341
- (fill! ( similar (x), o. epsilon), fill! ( similar ( x), o. epsilon), fill! ( similar ( x), o. epsilon))
340
+ (onevalue ( o. epsilon, x), onevalue ( o. epsilon, x), onevalue ( o. epsilon, x ))
342
341
343
342
function apply! (o:: AMSGrad , state, x, dx)
344
343
η, β, ϵ = o. eta, o. beta, o. epsilon
345
-
346
344
mt, vt, v̂t = state
347
345
348
- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
349
- vt′ = @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx ^ 2
350
- v̂t′ = @. . v̂t = max (v̂t, vt)
351
- dx′ = @. . η * mt / (sqrt (v̂t) + ϵ)
346
+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
347
+ @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx ^ 2
348
+ @. . v̂t = max (v̂t, vt)
349
+ dx′ = @lazy η * mt / (sqrt (v̂t) + ϵ)
352
350
353
- return (mt′ , vt′ , v̂t′ ), dx′
351
+ return (mt, vt, v̂t), dx′
354
352
end
355
353
356
354
"""
@@ -381,12 +379,12 @@ function apply!(o::NADAM, state, x, dx)
381
379
382
380
mt, vt, βt = state
383
381
384
- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
385
- vt′ = @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
386
- dx′ = @. . (β[1 ] * mt / (1 - β[1 ] * βt[1 ]) + (1 - β[1 ]) * dx / (1 - βt[1 ])) /
382
+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
383
+ @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
384
+ dx′ = @lazy (β[1 ] * mt / (1 - β[1 ] * βt[1 ]) + (1 - β[1 ]) * dx / (1 - βt[1 ])) /
387
385
(sqrt (vt * β[2 ] / (1 - βt[2 ])) + ϵ) * η
388
386
389
- return (mt′ , vt′ , βt .* β), dx′
387
+ return (mt, vt, βt .* β), dx′
390
388
end
391
389
392
390
"""
@@ -405,7 +403,7 @@ weight decay regularization.
405
403
(no need to change default)
406
404
"""
407
405
ADAMW (η = 1f-3 , β = (9f-1 , 9.99f-1 ), γ = 0 , ϵ = eps (typeof (η))) =
408
- OptimiserChain (ADAM {typeof(η)} (η, β, ϵ), WeightDecay (γ))
406
+ OptimiserChain (ADAM {typeof(η)} (η, β, ϵ), WeightDecay {typeof(η)} (γ))
409
407
410
408
"""
411
409
AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
@@ -434,11 +432,11 @@ function apply!(o::AdaBelief, state, x, dx)
434
432
η, β, ϵ = o. eta, o. beta, o. epsilon
435
433
mt, st = state
436
434
437
- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
438
- st′ = @. . st = β[2 ] * st + (1 - β[2 ]) * (dx - mt)^ 2
439
- dx′ = @. . η * mt / (sqrt (st) + ϵ)
435
+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
436
+ @. . st = β[2 ] * st + (1 - β[2 ]) * (dx - mt)^ 2
437
+ dx′ = @lazy η * mt / (sqrt (st) + ϵ)
440
438
441
- return (mt′ , st′ ), dx′
439
+ return (mt, st), dx′
442
440
end
443
441
444
442
"""
@@ -457,7 +455,7 @@ WeightDecay() = WeightDecay(5f-4)
457
455
init (o:: WeightDecay , x:: AbstractArray ) = nothing
458
456
459
457
function apply! (o:: WeightDecay , state, x, dx)
460
- dx′ = @. . dx + o. wd * x
458
+ dx′ = @lazy dx + o. wd * x
461
459
462
460
return state, dx′
463
461
end
@@ -478,7 +476,7 @@ init(o::ClipGrad, x::AbstractArray) = nothing
478
476
479
477
function apply! (o:: ClipGrad , state, x, dx)
480
478
δ = convert (float (eltype (x)), o. delta)
481
- dx′ = @. . clamp (dx, - δ, δ)
479
+ dx′ = @lazy clamp (dx, - δ, δ)
482
480
483
481
return state, dx′
484
482
end
@@ -510,7 +508,7 @@ function apply!(o::ClipNorm, state, x, dx)
510
508
end
511
509
λ = min (o. omega / nrm, 1 )
512
510
513
- return state, @. . dx * λ
511
+ return state, @lazy dx * λ
514
512
end
515
513
516
514
"""
0 commit comments