@@ -108,9 +108,9 @@ init(o::RMSProp, x::AbstractArray) = zero(x)
108
108
# One RMSProp step.
#
# `state` is the running average of squared gradients and is overwritten in
# place. The returned gradient is a freshly allocated array, so the caller's
# `dx` is left untouched. `x` is accepted for interface uniformity but is not
# read here.
#
# Returns `(state, rescaled_gradient)`.
function apply(o::RMSProp, state, x, dx)
    decay = o.rho

    # Exponential moving average of dx^2, written back into `state`.
    state .= decay .* state .+ (1 .- decay) .* dx .^ 2

    # Scale each gradient entry by eta / (sqrt(average) + epsilon).
    rescaled = dx .* (o.eta ./ (sqrt.(state) .+ o.epsilon))

    return state, rescaled
end
115
115
116
116
# Calling an `RMSProp` rule as a function forwards all arguments to `update`.
function (o::RMSProp)(state, m, dm)
    return update(o, state, m, dm)
end
@@ -145,9 +145,9 @@ function apply(o::ADAM{T}, state, x, dx) where T
145
145
146
146
@. mt = β[1 ] * mt + (one (T) - β[1 ]) * dx
147
147
@. vt = β[2 ] * vt + (one (T) - β[2 ]) * dx ^ 2
148
- dx = @. mt / (one (T) - βt[1 ]) / (sqrt (vt / (one (T) - βt[2 ])) + ϵ) * η
148
+ dx′ = @. mt / (one (T) - βt[1 ]) / (sqrt (vt / (one (T) - βt[2 ])) + ϵ) * η
149
149
150
- return (mt, vt, βt .* β), dx
150
+ return (mt, vt, βt .* β), dx′
151
151
end
152
152
153
153
"""
@@ -185,12 +185,12 @@ function apply(o::RADAM, state, x, dx)
185
185
ρ = ρ∞ - 2 * t * βt[2 ] / (1 - βt[2 ])
186
186
if ρ > 4
187
187
r = sqrt ((ρ - 4 ) * (ρ - 2 ) * ρ∞/ ((ρ∞ - 4 ) * (ρ∞ - 2 ) * ρ))
188
- dx = @. mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ) * η * r
188
+ dx′ = @. mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ) * η * r
189
189
else
190
- dx = @. mt / (1 - βt[1 ]) * η
190
+ dx′ = @. mt / (1 - βt[1 ]) * η
191
191
end
192
192
193
- return (mt, vt, βt .* β, t + 1 ), dx
193
+ return (mt, vt, βt .* β, t + 1 ), dx′
194
194
end
195
195
196
196
"""
@@ -224,9 +224,9 @@ function apply(o::AdaMax, state, x, dx)
224
224
225
225
@. mt = β[1 ] * mt + (1 - β[1 ]) * dx
226
226
@. ut = max (β[2 ] * ut, abs (dx))
227
- dx = @. (η/ (1 - βt[1 ])) * mt/ (ut + ϵ)
227
+ dx′ = @. (η/ (1 - βt[1 ])) * mt/ (ut + ϵ)
228
228
229
- return (mt, ut, βt .* β), dx
229
+ return (mt, ut, βt .* β), dx′
230
230
end
231
231
232
232
"""
@@ -263,9 +263,9 @@ function apply(o::OADAM, state, x, dx)
263
263
@. vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
264
264
@. dx = - dx_
265
265
@. dx_ = η * mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ)
266
- dx = @. dx + 2 * dx_
266
+ dx′ = @. dx + 2 * dx_
267
267
268
- return (mt, vt, βt .* β, dx_), dx
268
+ return (mt, vt, βt .* β, dx_), dx′
269
269
end
270
270
271
271
"""
@@ -296,9 +296,9 @@ function apply(o::ADAGrad, state, x, dx)
296
296
acc = state
297
297
298
298
@. acc += dx^ 2
299
- dx = @. dx * η / (sqrt (acc) + ϵ)
299
+ dx′ = @. dx * η / (sqrt (acc) + ϵ)
300
300
301
- return acc, dx
301
+ return acc, dx′
302
302
end
303
303
304
304
"""
@@ -330,10 +330,10 @@ function apply(o::ADADelta, state, x, dx)
330
330
@. acc = ρ * acc + (1 - ρ) * dx^ 2
331
331
# DON'T remove epsilon from numerator
332
332
# or even out of the square roots
333
- dx = @. dx * sqrt (Δacc + ϵ) / sqrt (acc + ϵ)
333
+ dx′ = @. dx * sqrt (Δacc + ϵ) / sqrt (acc + ϵ)
334
334
@. Δacc = ρ * Δacc + (1 - ρ) * dx^ 2
335
335
336
- return (acc, Δacc), dx
336
+ return (acc, Δacc), dx′
337
337
end
338
338
339
339
"""
@@ -370,9 +370,9 @@ function apply(o::AMSGrad, state, x, dx)
370
370
@. mt = β[1 ] * mt + (1 - β[1 ]) * dx
371
371
@. vt = β[2 ] * vt + (1 - β[2 ]) * dx ^ 2
372
372
@. v̂t = max (v̂t, vt)
373
- dx = @. η * mt / (sqrt (v̂t) + ϵ)
373
+ dx′ = @. η * mt / (sqrt (v̂t) + ϵ)
374
374
375
- return (mt, vt, v̂t), dx
375
+ return (mt, vt, v̂t), dx′
376
376
end
377
377
378
378
"""
@@ -407,10 +407,10 @@ function apply(o::NADAM, state, x, dx)
407
407
408
408
@. mt = β[1 ] * mt + (1 - β[1 ]) * dx
409
409
@. vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
410
- dx = @. (β[1 ] * mt / (1 - β[1 ] * βt[1 ]) + (1 - β[1 ]) * dx / (1 - βt[1 ])) /
410
+ dx′ = @. (β[1 ] * mt / (1 - β[1 ] * βt[1 ]) + (1 - β[1 ]) * dx / (1 - βt[1 ])) /
411
411
(sqrt (vt * β[2 ] / (1 - βt[2 ])) + ϵ) * η
412
412
413
- return (mt, vt, βt .* β), dx
413
+ return (mt, vt, βt .* β), dx′
414
414
end
415
415
416
416
"""
@@ -462,9 +462,9 @@ function apply(o::AdaBelief, state, x, dx)
462
462
463
463
@. mt = β[1 ] * mt + (1 - β[1 ]) * dx
464
464
@. st = β[2 ] * st + (1 - β[2 ]) * (dx - mt)^ 2
465
- dx = @. η * mt / (sqrt (st) + ϵ)
465
+ dx′ = @. η * mt / (sqrt (st) + ϵ)
466
466
467
- return (mt, st), dx
467
+ return (mt, st), dx′
468
468
end
469
469
470
470
"""
@@ -485,9 +485,9 @@ init(o::WeightDecay, x::AbstractArray) = nothing
485
485
# Calling a `WeightDecay` rule as a function forwards all arguments to `update`.
function (o::WeightDecay)(state, m, dm)
    return update(o, state, m, dm)
end
486
486
487
487
# Weight-decay step: adds `o.wd * x` to the incoming gradient.
#
# `state` is passed through unchanged. The result is a newly allocated array;
# neither `x` nor `dx` is mutated.
#
# Returns `(state, decayed_gradient)`.
function apply(o::WeightDecay, state, x, dx)
    decayed = dx .+ o.wd .* x
    return state, decayed
end
492
492
493
493
"""
0 commit comments