@@ -83,10 +83,9 @@ function _step!(opt::Adam, θ::T, ∇::T, i; dispose::Bool) where T <: AbstractA
8383 " instead: `$(size(θ)) ` vs `$(size(∇)) `." ))
8484
8585 # Debiasing.
86- lr = opt. lr * √ (1f0 - opt. β2^ opt. current_step) / (1f0 - opt. β1^ opt. current_step)
8786 adam_step_kernel!(get_backend(opt))(
8887 opt. μ[i], opt. ν[i], θ, ∇,
89- opt. lr, opt. β1, opt. β2, opt. ϵ; ndrange= length(θ))
88+ opt. lr, opt. β1, opt. β2, opt. ϵ, opt . current_step ; ndrange= length(θ))
9089
9190 dispose && KA. unsafe_free!(∇)
9291
9594
9695@kernel function adam_step_kernel!(
9796 μ, ν, Θ, @Const(∇), lr:: Float32 ,
98- β1:: Float32 , β2:: Float32 , ϵ:: Float32 ,
97+ β1:: Float32 , β2:: Float32 , ϵ:: Float32 , step :: UInt32 ,
9998)
10099 i = @index(Global)
101100 @inbounds ∇ᵢ = ∇[i]
102- @inbounds ωᵢ = Θ[i]
101+ ∇ᵢ² = ∇ᵢ ^ 2
103102
104- ∇ᵢ² = ∇ᵢ * ∇ᵢ
105103 @inbounds μᵢ = μ[i] = β1 * μ[i] + (1f0 - β1) * ∇ᵢ
106104 @inbounds νᵢ = ν[i] = β2 * ν[i] + (1f0 - β2) * ∇ᵢ²
107105
108- @inbounds Θ[i] = ωᵢ - (lr * μᵢ) / (√ νᵢ + ϵ)
106+ # Debiasing.
107+ μ̂ = μᵢ / (1f0 - β1^ step)
108+ ν̂ = νᵢ / (1f0 - β2^ step)
109+
110+ @inbounds ωᵢ = Θ[i]
111+ @inbounds Θ[i] = ωᵢ - lr * μ̂ / (√ ν̂ + ϵ)
109112end
110113
114+ function exp_scheduler(lr_start:: Float32 , lr_end:: Float32 , steps:: Int )
115+ function _scheduler(step:: Int )
116+ (step < 0 || (lr_start ≈ 0f0 && lr_end ≈ 0f0 )) && return 0f0
117+
118+ t = clamp(Float32(step / steps), 0f0 , 1f0 )
119+ return exp(log(lr_start) * (1 - t) + log(lr_end) * t)
120+ end
121+ return _scheduler
122+ end
0 commit comments