@@ -501,8 +501,8 @@ function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
 end
 
 """
-    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
-    AdamW(; [eta, beta, lambda, epsilon])
+    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8; couple = true)
+    AdamW(; [eta, beta, lambda, epsilon, couple])
 
 [AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
@@ -516,12 +516,54 @@ Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`]
 - Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation.
 - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
-"""
-AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
-  OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
+- Keyword `couple`: If `true`, the weight decay is coupled with the learning rate, as in PyTorch's AdamW.
+  This corresponds to an update of the form `x = x - η * (dx + λ * x)`, where `dx` is the
+  update from Adam with learning rate 1.
+  If `false`, the weight decay is decoupled from the learning rate, in the spirit of the original paper.
+  This corresponds to an update of the form `x = x - η * dx - λ * x`.
+  Default is `true`.
+
+!!! warning "Breaking change in v0.4"
+    With version 0.4 the default update rule for AdamW has changed to match the PyTorch implementation.
+    The previous rule, which is closer to the original paper, can be obtained by setting `AdamW(..., couple=false)`.
+    See [this issue](https://github.com/FluxML/Flux.jl/issues/2433) for more details.
+"""
+struct AdamW{T1,T2,T3,T4} <: AbstractRule
+  eta::T1
+  beta::T2
+  lambda::T3
+  epsilon::T4
+  couple::Bool
+end
+
+function AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8; couple::Bool = true)
+  η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
+  AdamW(η, β, λ, ϵ, couple)
+end
+
+AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0.0, epsilon = 1e-8, kw...) =
+  AdamW(eta, beta, lambda, epsilon; kw...)
+
+init(o::AdamW, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
+
+function apply!(o::AdamW, state, x::AbstractArray{T}, dx) where T
+  η, β, ϵ, λ = T(o.eta), T.(o.beta), T(o.epsilon), T(o.lambda)
+  mt, vt, βt = state
 
-AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
-  OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))
+  # standard Adam update with learning rate eta=1
+  @.. mt = β[1] * mt + (1 - β[1]) * dx
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
+  dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
+
+  # apply learning rate and weight decay
+  if o.couple
+    dx′′ = @lazy η * (dx′ + λ * x)
+  else
+    dx′′ = @lazy η * dx′ + λ * x
+  end
+
+  return (mt, vt, βt .* β), dx′′
+end
 
 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
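For context, a minimal usage sketch (not part of the diff) comparing the two update rules selected by the new `couple` keyword. It assumes the exported `Optimisers.setup`/`Optimisers.update!` API applied to a plain array; per the docstring above, the `couple = false` case is expected to agree with the removed `OptimiserChain(Adam, WeightDecay)` implementation.

```julia
using Optimisers

x  = [1.0, 2.0, 3.0]   # parameters
dx = [0.1, 0.1, 0.1]   # gradient
η, λ = 0.01, 0.1

# Coupled decay (new default): the decay term is scaled by the learning rate,
# as in PyTorch's AdamW.
x1 = copy(x)
s1 = Optimisers.setup(AdamW(η, (0.9, 0.999), λ), x1)
Optimisers.update!(s1, x1, dx)

# Decoupled decay: closer to the original paper; the full λ * x decay is
# applied each step, independent of η.
x2 = copy(x)
s2 = Optimisers.setup(AdamW(η, (0.9, 0.999), λ; couple = false), x2)
Optimisers.update!(s2, x2, dx)

# The previous chained form, expected to match the `couple = false` result.
x3 = copy(x)
s3 = Optimisers.setup(OptimiserChain(Adam(η, (0.9, 0.999)), WeightDecay(λ)), x3)
Optimisers.update!(s3, x3, dx)

@show x1 x2 x3
```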