Skip to content

Commit 2bf9e60

Browse files
authored
Centred RMSProp (#51)
* Centred RMSProp * try making this a keyword * description, show * add a d to keyword * fixup * require recent Zygote
1 parent 33c8144 commit 2bf9e60

File tree

3 files changed

+29
-7
lines changed

3 files changed

+29
-7
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313
[compat]
1414
ChainRulesCore = "1"
1515
Functors = "0.2.8"
16+
Zygote = "0.6.40"
1617
julia = "1.6"
1718

1819
[extras]

src/rules.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,37 +84,55 @@ function apply!(o::Nesterov, state, x, dx)
8484
end
8585

8686
"""
87-
RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)))
87+
RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centred = false)
8888
8989
Optimizer using the
9090
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
9191
algorithm. Often a good choice for recurrent networks. Parameters other than learning rate
9292
generally don't need tuning.
9393
94+
[Centred RMSProp](http://arxiv.org/abs/1308.08500) is a variant which normalises
95+
gradients by an estimate their variance, instead of their second moment.
96+
9497
# Parameters
9598
- Learning rate (`η`): Amount by which gradients are discounted before updating
9699
the weights.
97100
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
98101
prominent direction, in effect dampening oscillations.
99102
- Machine epsilon (`ϵ`): Constant to prevent division by zero
100103
(no need to change default)
104+
- Keyword `centred` (or `centered`): Indicates whether to use centred variant
105+
of the algorithm.
101106
"""
102107
struct RMSProp{T}
103108
eta::T
104109
rho::T
105110
epsilon::T
111+
centred::Bool
106112
end
107-
RMSProp= 1f-3, ρ = 9f-1, ϵ = eps(typeof(η))) = RMSProp{typeof(η)}(η, ρ, ϵ)
113+
RMSProp= 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centred::Bool = false, centered::Bool = false) =
114+
RMSProp{typeof(η)}(η, ρ, ϵ, centred | centered)
108115

109-
init(o::RMSProp, x::AbstractArray) = zero(x)
116+
init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)
110117

111118
function apply!(o::RMSProp, state, x, dx)
112-
η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state
119+
η, ρ, ϵ = o.eta, o.rho, o.epsilon
120+
quad, lin = state
113121

114-
@.. acc = ρ * acc + (1 - ρ) * abs2(dx)
115-
dx′ = @lazy dx */ (sqrt(acc) + ϵ))
122+
@.. quad = ρ * quad + (1 - ρ) * abs2(dx)
123+
if o.centred
124+
@.. lin = ρ * lin + (1 - ρ) * dx
125+
end
126+
dx′ = @lazy dx * η / (sqrt(quad - abs2(lin)) + ϵ)
116127

117-
return acc, dx′
128+
return (quad, lin), dx′
129+
end
130+
131+
function Base.show(io::IO, o::RMSProp)
132+
show(io, typeof(o))
133+
print(io, "(")
134+
join(io, [o.eta, o.rho, o.epsilon], ", ")
135+
print(io, "; centred = ", o.centred, ")")
118136
end
119137

120138
"""

test/rules.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@ RULES = [
1414
OptimiserChain(ClipNorm(), Adam(0.001)),
1515
OptimiserChain(ClipGrad(0.5), Momentum()),
1616
OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),
17+
# Not the default:
18+
RMSProp(centred = true),
1719
]
1820

1921
name(o) = typeof(o).name.name # just for printing testset headings
2022
name(o::OptimiserChain) = join(name.(o.opts), "")
23+
name(o::RMSProp) = o.centred ? "RMSProp(centred = true)" : :RMSProp
2124

2225
LOG = Dict() # for debugging these testsets, this makes it easy to plot each optimiser's loss
2326

0 commit comments

Comments
 (0)