@@ -406,7 +406,7 @@ ADAMW(η = 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η))) =
406
406
OptimiserChain (ADAM {typeof(η)} (η, β, ϵ), WeightDecay {typeof(η)} (γ))
407
407
408
408
"""
409
- AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)) )
409
+ AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = 1e-16 )
410
410
411
411
The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known
412
412
ADAM optimiser.
@@ -424,19 +424,19 @@ struct AdaBelief{T}
424
424
beta:: Tuple{T, T}
425
425
epsilon:: T
426
426
end
427
- AdaBelief (η = 1f-3 , β = (9f-1 , 9.99f-1 ), ϵ = eps ( typeof (η) )) = AdaBelief {typeof(η)} (η, β, ϵ)
427
+ AdaBelief (η = 1f-3 , β = (9f-1 , 9.99f-1 ), ϵ = oftype (η, 1e-16 )) = AdaBelief {typeof(η)} (η, β, ϵ)
428
428
429
- init (o:: AdaBelief , x:: AbstractArray ) = (zero (x), zero (x))
429
+ init (o:: AdaBelief , x:: AbstractArray ) = (zero (x), zero (x), o . beta )
430
430
431
431
function apply! (o:: AdaBelief , state, x, dx)
432
432
η, β, ϵ = o. eta, o. beta, o. epsilon
433
- mt, st = state
433
+ mt, st, βt = state
434
434
435
435
@. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
436
- @. . st = β[2 ] * st + (1 - β[2 ]) * abs2 (dx - mt)
437
- dx′ = @lazy η * mt / (sqrt (st) + ϵ)
436
+ @. . st = β[2 ] * st + (1 - β[2 ]) * abs2 (dx - mt) + ϵ
437
+ dx′ = @lazy η * mt / (1 - βt[ 1 ]) / ( sqrt (st / ( 1 - βt[ 2 ]) ) + ϵ)
438
438
439
- return (mt, st), dx′
439
+ return (mt, st, βt .* β ), dx′
440
440
end
441
441
442
442
"""
0 commit comments