"""
# Adam
## Constructor
```julia
    Adam(; alpha=0.0001, beta_mean=0.9, beta_var=0.999, epsilon=1e-8)
```
## Description
Adam is a gradient-based optimizer that chooses its search direction from running estimates of the first two moments of the gradient vector. This makes it suitable for problems with a stochastic objective and, hence, a stochastic gradient. The method is introduced in [1], which also introduces the related AdaMax method; see `?AdaMax` for more information on that method.

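## Example
A minimal usage sketch with the Rosenbrock test function:
```julia
rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
optimize(rosenbrock, zeros(2), Adam(), Optim.Options(iterations = 10_000))
```
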
## References
[1] Kingma, D. P. and Ba, J. L. (2014). Adam: A Method for Stochastic Optimization. https://arxiv.org/abs/1412.6980
"""
struct Adam{T, Tm} <: FirstOrderOptimizer
    α::T         # step size (alpha)
    β₁::T        # decay rate of the first-moment estimate (beta_mean)
    β₂::T        # decay rate of the second-moment estimate (beta_var)
    ϵ::T         # small constant for numerical stability (epsilon)
    manifold::Tm
end
Adam(; alpha = 0.0001, beta_mean = 0.9, beta_var = 0.999, epsilon = 1e-8) =
    Adam(alpha, beta_mean, beta_var, epsilon, Flat())
Base.summary(::Adam) = "Adam"
function default_options(method::Adam)
    (; allow_f_increases = true, iterations = 10_000)
end

mutable struct AdamState{Tx, T, Tz, Tm, Tu, Ti} <: AbstractOptimizerState
    x::Tx           # current iterate
    x_previous::Tx  # previous iterate
    f_x_previous::T # objective value at the previous iterate
    s::Tx           # search direction buffer
    z::Tz           # working copy of the iterate used by the update
    m::Tm           # first-moment (mean) estimate of the gradient
    u::Tu           # second-moment (uncentered variance) estimate of the gradient
    iter::Ti        # iteration counter used for bias correction
end
function reset!(method, state::AdamState, obj, x)
    value_gradient!!(obj, x)
end
function initial_state(method::Adam, options, d, initial_x::AbstractArray{T}) where T
    initial_x = copy(initial_x)

    value_gradient!!(d, initial_x)
    α, β₁, β₂ = method.α, method.β₁, method.β₂

    z = copy(initial_x)
    m = copy(gradient(d))
    u = fill(zero(m[1]^2), length(m))
    a = 1 - β₁
    iter = 0

    AdamState(initial_x,          # Maintain current state in state.x
              copy(initial_x),    # Maintain previous state in state.x_previous
              real(T(NaN)),       # Store previous f in state.f_x_previous
              similar(initial_x), # Maintain current search direction in state.s
              z,
              m,
              u,
              iter)
end

function update_state!(d, state::AdamState{T}, method::Adam) where T
    state.iter += 1
    value_gradient!(d, state.x)
    α, β₁, β₂, ϵ = method.α, method.β₁, method.β₂, method.ϵ
    a = 1 - β₁
    b = 1 - β₂

    m, u, z = state.m, state.u, state.z
    v = u # alias: `v` is the paper's name for the second-moment estimate stored in state.u
    m .= β₁ .* m .+ a .* gradient(d)
    v .= β₂ .* v .+ b .* gradient(d) .^ 2
    # Bias-corrected update, with the corrections m̂ = m/(1-β₁^t) and v̂ = v/(1-β₂^t)
    # folded into a single broadcast. Note that Algorithm 1 in [1] uses sqrt(v̂) + ϵ,
    # whereas here ϵ sits inside the square root.
    @. z = z - α * m / (1 - β₁^state.iter) / sqrt(v / (1 - β₂^state.iter) + ϵ)
    # [1] also suggests the equivalent step size αₜ = α * sqrt(1 - β₂^t) / (1 - β₁^t),
    # applied as z .= z .- αₜ .* m ./ (sqrt.(v) .+ ϵ); it is not used here because of
    # the different placement of ϵ.

    for _i in eachindex(z)
        # Since m and u start at zero, z can become NaN if the initial gradient is
        # exactly zero, e.g. for
        #   rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
        #   optimize(rosenbrock, zeros(2), Adam(), Optim.Options(iterations=10000))
        # so guard against that by keeping the previous coordinate value.
        if isnan(z[_i])
            z[_i] = state.x[_i]
        end
    end
    # Update the current position
    state.x .= z
    false # no line search is used, so never signal a line search failure
end
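
# A minimal reference sketch (not called anywhere in this file): a single textbook
# Adam step following Algorithm 1 of [1], with ϵ added outside the square root.
# The name `adam_reference_step` and its arguments are illustrative only; the
# optimizer above folds the same update into update_state! instead.
function adam_reference_step(x, g, m, v, t; α = 0.001, β₁ = 0.9, β₂ = 0.999, ϵ = 1e-8)
    m = β₁ .* m .+ (1 - β₁) .* g        # update biased first-moment estimate
    v = β₂ .* v .+ (1 - β₂) .* g .^ 2   # update biased second-moment estimate
    m̂ = m ./ (1 - β₁^t)                 # bias-corrected first moment
    v̂ = v ./ (1 - β₂^t)                 # bias-corrected second moment
    x = x .- α .* m̂ ./ (sqrt.(v̂) .+ ϵ)  # parameter update
    return x, m, v
end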

function trace!(tr, d, state, iteration, method::Adam, options, curr_time=time())
    common_trace!(tr, d, state, iteration, method, options, curr_time)
end