@@ -532,10 +532,26 @@ AdaBelief(η::Real, β::Tuple, state::IdDict) = AdaBelief(η, β, EPS, state)
532
532
533
533
"""
    apply!(o::AdaBelief, x, Δ)

Compute the AdaBelief optimiser step for parameter `x` given gradient `Δ`,
overwriting `Δ` in place with the update to apply and returning it.

Per-parameter state is created lazily in `o.state` as a tuple `(mt, st, βp)`:
`mt` is the first-moment estimate, `st` the belief (variance-like) accumulator,
and `βp` holds the running powers `β₁ᵗ, β₂ᵗ` used for bias correction.
"""
function apply!(o::AdaBelief, x, Δ)
  η, β = o.eta, o.beta

  # Fetch (or initialise) this parameter's optimiser state; the type assertion
  # keeps the lookup type-stable for the broadcasts below.
  mt, st, βp = get!(o.state, x) do
    (zero(x), zero(x), Float64[β[1], β[2]])
  end :: Tuple{typeof(x), typeof(x), Vector{Float64}}

  #= st is a variance and can go to zero. This is in contrast to ADAM, which uses the
  second moment which is usually far enough from zero. This is problematic, since st
  can be slightly negative due to numerical error, and the square root below will fail.
  Also, if we want to differentiate through the optimizer, √0 is not differentiable.
  To protect against this, we add a small number, st -> st + eps2.
  The original implementation (https://github.com/juntang-zhuang/Adabelief-Optimizer)
  uses the square of Adam's epsilon, which we do here.
  See also: https://github.com/juntang-zhuang/Adabelief-Optimizer/issues/61 =#
  eps2 = o.epsilon^2  # TODO: make epsilon^2 the default in next breaking release

  # First moment: exponential moving average of the gradient.
  @. mt = β[1] * mt + (1 - β[1]) * Δ
  # Belief accumulator: EMA of the squared deviation of Δ from mt
  # (conj handles complex-valued parameters), floored by eps2 as explained above.
  @. st = β[2] * st + (1 - β[2]) * (Δ - mt) * conj(Δ - mt) + eps2
  # Bias-corrected step: dividing by (1 - βp[i]) undoes the zero-initialisation
  # bias of mt and st, exactly as in Adam.
  @. Δ = η * mt / (1 - βp[1]) / (√(st / (1 - βp[2])) + eps2)
  # Advance the β powers (β₁ᵗ, β₂ᵗ) for the next call's bias correction.
  βp .= βp .* β

  return Δ
end
0 commit comments