Commit 31d8d96

Use zeros instead of fill in Adam (#1075)
* Use zeros instead of fill
* Update adamax.jl
* Apply suggestions from code review
* Fix it once and for all
1 parent b0ba898 commit 31d8d96

4 files changed (+21 lines, -43 lines)

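The change named in the commit title replaces `fill(zero(m[1]^2), length(m))` with `zero(m)` when initializing the second-moment accumulator in both solvers' `initial_state`. A minimal sketch of the two forms (the value of `m` here is made up; in the package it is a copy of the gradient):

```julia
# Illustrative only: `m` stands in for the gradient copy made in initial_state.
m = [0.5, -1.25, 3.0]

u_fill = fill(zero(m[1]^2), length(m))  # old form: indexes m[1] and always builds a Vector
u_zero = zero(m)                        # new form: zero array with m's shape and element type

@assert u_fill == u_zero == [0.0, 0.0, 0.0]
```

Both give an all-zero array here; `zero(m)` simply avoids indexing into `m` and preserves the array type of the gradient.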

docs/src/algo/adam_adamax.md

Lines changed: 2 additions & 1 deletion

````diff
@@ -14,7 +14,8 @@ where `alpha` is the step length or learning parameter. `beta_mean` and `beta_va
 ```julia
 AdaMax(; alpha=0.002,
          beta_mean=0.9,
-         beta_var=0.999)
+         beta_var=0.999,
+         epsilon=1e-8)
 ```
 where `alpha` is the step length or learning parameter. `beta_mean` and `beta_var` are exponential decay parameters for the first and second moments estimates. Setting these closer to 0 will cause past iterates to matter less for the current steps and setting them closer to 1 means emphasizing past iterates more.
````
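
With `epsilon` now part of the documented signature, the constructor accepts it directly; a minimal usage sketch (the objective, starting point, and iteration cap are illustrative, not from the commit):

```julia
using Optim

# Keyword values match the documented defaults; `epsilon` is the newly documented keyword.
method = AdaMax(alpha = 0.002, beta_mean = 0.9, beta_var = 0.999, epsilon = 1e-8)

f(x) = sum(abs2, x)   # illustrative smooth objective
optimize(f, fill(2.0, 3), method, Optim.Options(iterations = 10_000))
```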

src/multivariate/solvers/first_order/adam.jl

Lines changed: 5 additions & 17 deletions

````diff
@@ -17,19 +17,19 @@ struct Adam{T, Tm} <: FirstOrderOptimizer
     ϵ::T
     manifold::Tm
 end
+# could use epsilon = T->sqrt(eps(T)) and input the promoted type
 Adam(; alpha = 0.0001, beta_mean = 0.9, beta_var = 0.999, epsilon = 1e-8) =
     Adam(alpha, beta_mean, beta_var, epsilon, Flat())
 Base.summary(::Adam) = "Adam"
 function default_options(method::Adam)
     (; allow_f_increases = true, iterations=10_000)
 end
 
-mutable struct AdamState{Tx, T, Tz, Tm, Tu, Ti} <: AbstractOptimizerState
+mutable struct AdamState{Tx, T, Tm, Tu, Ti} <: AbstractOptimizerState
     x::Tx
     x_previous::Tx
     f_x_previous::T
     s::Tx
-    z::Tz
     m::Tm
     u::Tu
     iter::Ti
@@ -43,17 +43,15 @@ function initial_state(method::Adam, options, d, initial_x::AbstractArray{T}) wh
     value_gradient!!(d, initial_x)
     α, β₁, β₂ = method.α, method.β₁, method.β₂
 
-    z = copy(initial_x)
     m = copy(gradient(d))
-    u = fill(zero(m[1]^2), length(m))
+    u = zero(m)
     a = 1 - β₁
     iter = 0
 
     AdamState(initial_x, # Maintain current state in state.x
               copy(initial_x), # Maintain previous state in state.x_previous
              real(T(NaN)), # Store previous f in state.f_x_previous
              similar(initial_x), # Maintain current search direction in state.s
-             z,
              m,
              u,
              iter)
@@ -66,25 +64,15 @@ function update_state!(d, state::AdamState{T}, method::Adam) where T
     a = 1 - β₁
     b = 1 - β₂
 
-    m, u, z = state.m, state.u, state.z
+    m, u = state.m, state.u
     v = u
     m .= β₁ .* m .+ a .* gradient(d)
     v .= β₂ .* v .+ b .* gradient(d) .^ 2
     # m̂ = m./(1-β₁^state.iter)
     # v̂ = v./(1-β₂^state.iter)
     #@. z = z - α*m̂/(sqrt(v̂+ϵ))
     αₜ = α * sqrt(1 - β₂^state.iter) / (1 - β₁^state.iter)
-    @. z = z - αₜ * m / (sqrt(v) + ϵ)
-
-    for _i in eachindex(z)
-        # since m and u start at 0, this can happen if the initial gradient is exactly 0
-        # rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
-        # optimize(rosenbrock, zeros(2), Adam(), Optim.Options(iterations=10000))
-        if isnan(z[_i])
-            z[_i] = state.x[_i]
-        end
-    end
-    state.x .= z
+    @. state.x = state.x - αₜ * m / (sqrt(v) + ϵ)
     # Update current position # x = x + alpha * s
     false # break on linesearch error
 end
````
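
Since `m` and `v` start at zero, an element whose gradient is exactly zero contributes a step of `0 / (sqrt(0) + ϵ)`, which is finite for any positive `ϵ`; the commit therefore drops the per-element NaN guard together with the intermediate `z` array and writes the update straight into `state.x`. The Rosenbrock call quoted in the removed comments is the case that guard was written for; as a sketch (not part of the test suite), it can be rerun to exercise the simplified update path:

```julia
using Optim

# From the removed comments: at zeros(2) the second gradient component of
# Rosenbrock is exactly zero, the scenario the deleted NaN guard targeted.
rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
res = optimize(rosenbrock, zeros(2), Adam(), Optim.Options(iterations = 10_000))
```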

src/multivariate/solvers/first_order/adamax.jl

Lines changed: 12 additions & 23 deletions

````diff
@@ -1,9 +1,8 @@
 """
-AdaMax(; alpha=0.002, beta_mean=0.9, beta_var=0.999)
-# Adam
+# AdaMax
 ## Constructor
 ```julia
-AdaMax(; alpha=0.002, beta_mean=0.9, beta_var=0.999)
+AdaMax(; alpha=0.002, beta_mean=0.9, beta_var=0.999, epsilon=1e-8)
 ```
 ## Description
 AdaMax is a gradient based optimizer that choses its search direction by building up estimates of the first two moments of the gradient vector. This makes it suitable for problems with a stochastic objective and thus gradient. The method is introduced in [1] where the related Adam method is also introduced, see `?Adam` for more information on that method.
@@ -16,22 +15,22 @@ struct AdaMax{T,Tm} <: FirstOrderOptimizer
     α::T
     β₁::T
     β₂::T
+    ϵ::T
     manifold::Tm
 end
-AdaMax(; alpha = 0.002, beta_mean = 0.9, beta_var = 0.999) =
-    AdaMax(alpha, beta_mean, beta_var, Flat())
+AdaMax(; alpha = 0.002, beta_mean = 0.9, beta_var = 0.999, epsilon = sqrt(eps(Float64))) =
+    AdaMax(alpha, beta_mean, beta_var, epsilon, Flat())
 Base.summary(::AdaMax) = "AdaMax"
 function default_options(method::AdaMax)
     (; allow_f_increases = true, iterations=10_000)
 end
 
 
-mutable struct AdaMaxState{Tx, T, Tz, Tm, Tu, Ti} <: AbstractOptimizerState
+mutable struct AdaMaxState{Tx, T, Tm, Tu, Ti} <: AbstractOptimizerState
     x::Tx
     x_previous::Tx
     f_x_previous::T
     s::Tx
-    z::Tz
     m::Tm
     u::Tu
     iter::Ti
@@ -45,17 +44,15 @@ function initial_state(method::AdaMax, options, d, initial_x::AbstractArray{T})
     value_gradient!!(d, initial_x)
     α, β₁, β₂ = method.α, method.β₁, method.β₂
 
-    z = copy(initial_x)
     m = copy(gradient(d))
-    u = fill(zero(m[1]^2), length(m))
+    u = zero(m)
     a = 1 - β₁
     iter = 0
 
     AdaMaxState(initial_x, # Maintain current state in state.x
                 copy(initial_x), # Maintain previous state in state.x_previous
                real(T(NaN)), # Store previous f in state.f_x_previous
                similar(initial_x), # Maintain current search direction in state.s
-               z,
                m,
                u,
                iter)
@@ -64,22 +61,14 @@ end
 function update_state!(d, state::AdaMaxState{T}, method::AdaMax) where T
     state.iter = state.iter+1
     value_gradient!(d, state.x)
-    α, β₁, β₂ = method.α, method.β₁, method.β₂
+    α, β₁, β₂, ϵ = method.α, method.β₁, method.β₂, method.ϵ
     a = 1 - β₁
-    m, u, z = state.m, state.u, state.z
+    m, u = state.m, state.u
 
     m .= β₁ .* m .+ a .* gradient(d)
-    u .= max.(β₂ .* u, abs.(gradient(d)))
-    z .= z .- (α ./ (1 - β₁^state.iter)) .* m ./ u
-    for _i in eachindex(z)
-        # since m and u start at 0, this can happen if the initial gradient is exactly 0
-        # rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
-        # optimize(rosenbrock, zeros(2), AdaMax(), Optim.Options(iterations=10000))
-        if isnan(z[_i])
-            z[_i] = state.x[_i]
-        end
-    end
-    state.x .= z
+    u .= max.(ϵ, max.(β₂ .* u, abs.(gradient(d)))) # I know it's not there in the paper but if m and u start at 0 for some element... NaN occurs next
+
+    @. state.x = state.x - (α / (1 - β₁^state.iter)) * m / u
     # Update current position # x = x + alpha * s
     false # break on linesearch error
 end
````
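
For AdaMax the denominator is the infinity-norm accumulator `u` itself, so an element whose gradient stays exactly zero really does produce `0 / 0`; flooring `u` at `ϵ` (which the added comment admits is not in the paper) keeps the step finite and makes the per-element NaN loop unnecessary. A standalone sketch of the clamp, with made-up values rather than the solver's state:

```julia
# Hypothetical values, not the package's state objects.
β₂ = 0.999
ϵ  = sqrt(eps(Float64))   # the new default epsilon for AdaMax
u  = zeros(2)             # u starts at zero, as in initial_state
g  = [0.0, 1.0]           # gradient whose first component is exactly zero

u_old = max.(β₂ .* u, abs.(g))           # old accumulator: first entry stays 0, giving 0/0 later
u_new = max.(ϵ, max.(β₂ .* u, abs.(g)))  # new accumulator: floored at ϵ

@assert u_old[1] == 0.0   # the step m[1]/u_old[1] would be 0/0 = NaN
@assert u_new[1] == ϵ     # the step m[1]/u_new[1] is 0/ϵ = 0
```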

test/multivariate/solvers/first_order/adam_adamax.jl

Lines changed: 2 additions & 2 deletions

````diff
@@ -19,7 +19,7 @@
     )
     run_optim_tests(Adam();
                     skip = skip,
-                    show_name = true)
+                    show_name = debug_printing)
 end
 @testset "AdaMax" begin
     f(x) = x[1]^4
@@ -42,6 +42,6 @@ end
     )
     run_optim_tests(AdaMax();
                     skip = skip,
-                    show_name=true,
+                    show_name=debug_printing,
                     iteration_exceptions = (("Trigonometric", 1_000_000,),))
 end
````
