
Commit 38c9d62

Add the option couple to AdamW and set the default to match pytorch (#188)
1 parent 4a78a55 · commit 38c9d62

5 files changed (+58 −11 lines): Project.toml, README.md, src/rules.jl, test/rules.jl, test/runtests.jl

Project.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes <[email protected]>"]
-version = "0.3.4"
+version = "0.4.0"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
```

README.md

Lines changed: 6 additions & 0 deletions

````diff
@@ -25,6 +25,12 @@ This was written as the new training system for [Flux.jl](https://github.com/Flu
 and also used by [Lux.jl](https://github.com/avik-pal/Lux.jl).
 But it can be used separately on any array, or anything else understood by [Functors.jl](https://github.com/FluxML/Functors.jl).
 
+
+> [!WARNING]
+> With version 0.4 the default update rule for AdamW has changed to match the pytorch implementation.
+> The previous rule, which is closer to the original paper, can be obtained by setting `AdamW(..., couple=false)`.
+> See [this issue](https://github.com/FluxML/Flux.jl/issues/2433) for more details.
+
 ## Installation
 
 ```julia
````
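As a quick illustration of the warning above (a usage sketch, not part of the commit): after this change the default constructor gives the pytorch-style rule, and the pre-0.4 behaviour is one keyword away.

```julia
using Optimisers

opt_new = AdamW(0.001)                  # v0.4 default: decay coupled with the learning rate
opt_old = AdamW(0.001; couple = false)  # previous behaviour, closer to the original paper
```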

src/rules.jl

Lines changed: 49 additions & 7 deletions

```diff
@@ -501,8 +501,8 @@ function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
 end
 
 """
-    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
-    AdamW(; [eta, beta, lambda, epsilon])
+    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8; couple = true)
+    AdamW(; [eta, beta, lambda, epsilon, couple])
 
 [AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
@@ -516,12 +516,54 @@ Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`]
 - Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation.
 - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
-"""
-AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
-  OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
+- Keyword `couple`: If `true`, the weight decay is coupled with the learning rate, as in pytorch's AdamW.
+  This corresponds to an update of the form `x = x - η * (dx + λ * x)`, where `dx` is the
+  update from Adam with learning rate 1.
+  If `false`, the weight decay is decoupled from the learning rate, in the spirit of the original paper.
+  This corresponds to an update of the form `x = x - η * dx - λ * x`.
+  Default is `true`.
+
+!!! warning "Breaking change in v0.4"
+    With version 0.4 the default update rule for AdamW has changed to match the pytorch implementation.
+    The previous rule, which is closer to the original paper, can be obtained by setting `AdamW(..., couple=false)`.
+    See [this issue](https://github.com/FluxML/Flux.jl/issues/2433) for more details.
+"""
+struct AdamW{T1,T2,T3,T4} <: AbstractRule
+  eta::T1
+  beta::T2
+  epsilon::T3
+  lambda::T4
+  couple::Bool
+end
+
+function AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8; couple::Bool = true)
+  η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
+  AdamW(η, β, λ, ϵ, couple)
+end
+
+AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0.0, epsilon = 1e-8, kw...) =
+  AdamW(eta, beta, lambda, epsilon; kw...)
+
+init(o::AdamW, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
+
+function apply!(o::AdamW, state, x::AbstractArray{T}, dx) where T
+  η, β, ϵ, λ = T(o.eta), T.(o.beta), T(o.epsilon), T(o.lambda)
+  mt, vt, βt = state
 
-AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
-  OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))
+  # standard Adam update with learning rate eta=1
+  @.. mt = β[1] * mt + (1 - β[1]) * dx
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
+  dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
+
+  # apply learning rate and weight decay
+  if o.couple
+    dx′′ = @lazy η * (dx′ + λ * x)
+  else
+    dx′′ = @lazy η * dx′ + λ * x
+  end
+
+  return (mt, vt, βt .* β), dx′′
+end
 
 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
```
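To make the docstring's two update forms concrete, here is a small stand-alone sketch in plain Julia (illustrative only, independent of the package internals); `dx` stands for the Adam direction computed with learning rate 1.

```julia
# couple = true  (pytorch-style): the decay term is scaled by the learning rate η
coupled_step(x, dx, η, λ)   = x .- η .* (dx .+ λ .* x)

# couple = false (spirit of the original paper): the decay term is independent of η
decoupled_step(x, dx, η, λ) = x .- η .* dx .- λ .* x

x, dx = [1.0, -2.0], [0.1, 0.3]
coupled_step(x, dx, 0.01, 0.1)    # decay contribution is η * λ * x = 0.001 * x
decoupled_step(x, dx, 0.01, 0.1)  # decay contribution is λ * x = 0.1 * x
```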

test/rules.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -15,7 +15,7 @@ RULES = [
   OptimiserChain(ClipGrad(0.5), Momentum()),
   OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),
   # Not the default:
-  RMSProp(centred = true),
+  RMSProp(centred = true), AdamW(couple=false),
 ]
 
 name(o) = typeof(o).name.name # just for printing testset headings
```
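The new `AdamW(couple=false)` entry is exercised by the same generic rule tests as the other entries in `RULES`. For reference, a minimal sketch of driving the rule through the public `Optimisers.setup`/`Optimisers.update` API (names and values here are illustrative, not taken from the test suite):

```julia
using Optimisers

x  = randn(Float32, 4)   # any array can act as the "model"
dx = randn(Float32, 4)   # a gradient of matching shape

state = Optimisers.setup(AdamW(eta = 0.001, lambda = 0.01), x)  # coupled decay (new default)
state, x = Optimisers.update(state, x, dx)                      # one optimisation step

# the decoupled variant is constructed the same way
state_old = Optimisers.setup(AdamW(eta = 0.001, lambda = 0.01, couple = false), x)
```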

test/runtests.jl

Lines changed: 1 addition & 2 deletions

```diff
@@ -332,8 +332,7 @@ end
 
 @testset "keyword arguments" begin
   @test Nesterov(rho=0.8, eta=0.1) === Nesterov(0.1, 0.8)
-  @test AdamW(lambda=0.3).opts[1] == Adam()
-  @test AdamW(lambda=0.3).opts[2] == WeightDecay(0.3)
+  @test AdamW(lambda=0.3, eta=0.1) == AdamW(0.1, (0.9, 0.999), 0.3, 1.0e-8)
 end
 
 @testset "forgotten gradient" begin
```
