Skip to content

Commit 3816725

Browse files
author
cossio
committed
AdaBelief bias correction
1 parent 25457f5 commit 3816725

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/optimise/optimisers.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -532,10 +532,16 @@ AdaBelief(η::Real, β::Tuple, state::IdDict) = AdaBelief(η, β, EPS, state)
532532

533533
function apply!(o::AdaBelief, x, Δ)
534534
η, β = o.eta, o.beta
535-
mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
535+
536+
mt, st, βp = get!(o.state, x) do
537+
(zero(x), zero(x), Float64[β[1], β[2]])
538+
end :: Tuple{typeof(x), typeof(x), Vector{Float64}}
539+
536540
@. mt = β[1] * mt + (1 - β[1]) * Δ
537541
@. st = β[2] * st + (1 - β[2]) *- mt) * conj- mt)
538-
@. Δ = η * mt / ((st) + o.epsilon)
542+
@. Δ = η * mt / (1 - βp[1]) / ((st / (1 - βp[2])) + o.epsilon)
543+
βp .= βp .* β
544+
539545
return Δ
540546
end
541547

0 commit comments

Comments
 (0)