diff --git a/src/rules.jl b/src/rules.jl
index f428666..3cd7743 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -499,7 +499,7 @@ function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
 end
 
 """
-    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8; couple = true)
+    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0.01, ϵ = 1e-8; couple = true)
     AdamW(; [eta, beta, lambda, epsilon, couple])
 
 [AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
@@ -534,12 +534,12 @@ struct AdamW{Teta,Tbeta<:Tuple,Tlambda,Teps} <: AbstractRule
   couple::Bool
 end
 
-function AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8; couple::Bool = true)
+function AdamW(η, β = (0.9, 0.999), λ = 0.01, ϵ = 1e-8; couple::Bool = true)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
   return AdamW(float(η), β, float(λ), float(ϵ), couple)
 end
 
-AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda= 0.0, epsilon = 1e-8, kw...) =
+AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda= 0.01, epsilon = 1e-8, kw...) =
   AdamW(eta, beta, lambda, epsilon; kw...)
 
 init(o::AdamW, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
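
A quick sketch of what the changed default means in practice, assuming Optimisers.jl's `setup`/`update` API; the parameter and gradient arrays below are placeholders, not part of this change:

```julia
using Optimisers

# With this diff, the default weight decay of AdamW is 0.01 instead of 0.0.
rule = AdamW()                      # lambda = 0.01 by default

# Code that relied on the previous default (no decay) can pass it explicitly:
old_rule = AdamW(lambda = 0.0)

# Typical usage with a plain parameter array (hypothetical values):
x  = randn(Float32, 3, 3)           # parameters
dx = randn(Float32, 3, 3)           # gradient
state = Optimisers.setup(rule, x)   # build optimiser state for x
state, x = Optimisers.update(state, x, dx)  # one AdamW step with decoupled decay
```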