Skip to content

Commit eaa7ee8

Browse files
authored
Merge pull request #1963 from cossio/bias
AdaBelief bias correction
2 parents 25457f5 + 6403805 commit eaa7ee8

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Flux"
22
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
3-
version = "0.13.0"
3+
version = "0.13.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/optimise/optimisers.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,10 +532,26 @@ AdaBelief(η::Real, β::Tuple, state::IdDict) = AdaBelief(η, β, EPS, state)
532532

533533
function apply!(o::AdaBelief, x, Δ)
534534
η, β = o.eta, o.beta
535-
mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
535+
536+
mt, st, βp = get!(o.state, x) do
537+
(zero(x), zero(x), Float64[β[1], β[2]])
538+
end :: Tuple{typeof(x), typeof(x), Vector{Float64}}
539+
540+
#= st is a variance and can go to zero. This is in contrast to ADAM, which uses the
541+
second moment which is usually far enough from zero. This is problematic, since st
542+
can be slightly negative due to numerical error, and the square root below will fail.
543+
Also, if we want to differentiate through the optimizer, √0 is not differentiable.
544+
To protect against this, we add a small number, st -> st + eps2.
545+
The original implementation (https://github.com/juntang-zhuang/Adabelief-Optimizer)
546+
uses the square of Adam's epsilon, which we do here.
547+
See also: https://github.com/juntang-zhuang/Adabelief-Optimizer/issues/61 =#
548+
eps2 = o.epsilon^2 # TODO: make epsilon^2 the default in next breaking release
549+
536550
@. mt = β[1] * mt + (1 - β[1]) * Δ
537-
@. st = β[2] * st + (1 - β[2]) * (Δ - mt) * conj(Δ - mt)
538-
@. Δ = η * mt / (√(st) + o.epsilon)
551+
@. st = β[2] * st + (1 - β[2]) * (Δ - mt) * conj(Δ - mt) + eps2
552+
@. Δ = η * mt / (1 - βp[1]) / (√(st / (1 - βp[2])) + eps2)
553+
βp .= βp .* β
554+
539555
return Δ
540556
end
541557

0 commit comments

Comments
 (0)