Skip to content

Commit dd8fbdc

Browse files
committed
Minor correction to optimizer & add exp lr scheduler
1 parent bc4a5c1 commit dd8fbdc

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/nn/adam.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ function _step!(opt::Adam, θ::T, ∇::T, i; dispose::Bool) where T <: AbstractA
8383
"instead: `$(size(θ))` vs `$(size(∇))`."))
8484

8585
# Debiasing.
86-
lr = opt.lr * (1f0 - opt.β2^opt.current_step) / (1f0 - opt.β1^opt.current_step)
8786
adam_step_kernel!(get_backend(opt))(
8887
opt.μ[i], opt.ν[i], θ, ∇,
89-
opt.lr, opt.β1, opt.β2, opt.ϵ; ndrange=length(θ))
88+
opt.lr, opt.β1, opt.β2, opt.ϵ, opt.current_step; ndrange=length(θ))
9089

9190
dispose && KA.unsafe_free!(∇)
9291

@@ -95,16 +94,29 @@ end
9594

9695
# Adam update for a single parameter element; the caller launches it with
# `ndrange=length(θ)`, i.e. one work-item per element.
#
# - `μ`, `ν`: running first / second moment estimates, updated in place.
# - `Θ`: parameters, updated in place.
# - `∇`: gradient, read-only.
# - `lr`, `β1`, `β2`, `ϵ`: step size, moment decay rates, denominator offset.
# - `step`: optimizer step counter used for bias correction.
#   NOTE(review): must be ≥ 1 — at `step == 0` both correction denominators
#   are `1 - β^0 = 0`.
@kernel function adam_step_kernel!(
    μ, ν, Θ, @Const(∇), lr::Float32,
    β1::Float32, β2::Float32, ϵ::Float32, step::UInt32,
)
    i = @index(Global)
    @inbounds begin
        ∇ᵢ = ∇[i]
        ∇ᵢ² = ∇ᵢ^2

        # Exponential moving averages of the gradient and its square.
        μᵢ = μ[i] = β1 * μ[i] + (1f0 - β1) * ∇ᵢ
        νᵢ = ν[i] = β2 * ν[i] + (1f0 - β2) * ∇ᵢ²

        # Debiasing.
        μ̂ = μᵢ / (1f0 - β1^step)
        ν̂ = νᵢ / (1f0 - β2^step)

        # Adam scales the step by the root of the second moment; dividing by
        # `ν̂` itself would make the update size depend on the gradient scale.
        Θ[i] -= lr * μ̂ / (sqrt(ν̂) + ϵ)
    end
end
110113

114+
"""
    exp_scheduler(lr_start::Float32, lr_end::Float32, steps::Int)

Build a learning-rate schedule that moves from `lr_start` to `lr_end`
geometrically (linearly in log-space) over `steps` steps.

Returns a closure `step::Int -> Float32`:

- a negative `step`, or both rates being zero, yields `0f0`;
- otherwise the progress `step / steps` is clamped to `[0, 1]`, so for a
  positive `steps` every `step ≥ steps` evaluates to (approximately) `lr_end`.

Both rates are passed through `log`, so they are expected to be positive.
"""
function exp_scheduler(lr_start::Float32, lr_end::Float32, steps::Int)
    function _scheduler(step::Int)
        # Guard first: `log(0f0)` is `-Inf`, and `-Inf * 0` in the blend below
        # would produce `NaN` when both rates are zero.
        (step < 0 || (iszero(lr_start) && iszero(lr_end))) && return 0f0

        t = clamp(Float32(step / steps), 0f0, 1f0)
        return exp(log(lr_start) * (1 - t) + log(lr_end) * t)
    end
    return _scheduler
end

0 commit comments

Comments
 (0)