Skip to content

Commit 91ade31

Browse files
authored
Merge pull request #75 from cossio/bias
AdaBelief bias correction and epsilon
2 parents df5e770 + 38e3301 commit 91ade31

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/rules.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ ADAMW(η = 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η))) =
406406
OptimiserChain(ADAM{typeof(η)}(η, β, ϵ), WeightDecay{typeof(η)}(γ))
407407

408408
"""
409-
AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
409+
AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = 1e-16)
410410
411411
The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known
412412
ADAM optimiser.
@@ -424,19 +424,19 @@ struct AdaBelief{T}
424424
beta::Tuple{T, T}
425425
epsilon::T
426426
end
427-
AdaBelief= 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AdaBelief{typeof(η)}(η, β, ϵ)
427+
AdaBelief= 1f-3, β = (9f-1, 9.99f-1), ϵ = oftype(η, 1e-16)) = AdaBelief{typeof(η)}(η, β, ϵ)
428428

429-
init(o::AdaBelief, x::AbstractArray) = (zero(x), zero(x))
429+
init(o::AdaBelief, x::AbstractArray) = (zero(x), zero(x), o.beta)
430430

431431
function apply!(o::AdaBelief, state, x, dx)
432432
η, β, ϵ = o.eta, o.beta, o.epsilon
433-
mt, st = state
433+
mt, st, βt = state
434434

435435
@.. mt = β[1] * mt + (1 - β[1]) * dx
436-
@.. st = β[2] * st + (1 - β[2]) * abs2(dx - mt)
437-
dx′ = @lazy η * mt / (sqrt(st) + ϵ)
436+
@.. st = β[2] * st + (1 - β[2]) * abs2(dx - mt) + ϵ
437+
dx′ = @lazy η * mt / (1 - βt[1]) / (sqrt(st / (1 - βt[2])) + ϵ)
438438

439-
return (mt, st), dx′
439+
return (mt, st, βt .* β), dx′
440440
end
441441

442442
"""

0 commit comments

Comments
 (0)