@@ -84,37 +84,55 @@ function apply!(o::Nesterov, state, x, dx)
84
84
end
85
85
86
86
"""
87
- RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)))
87
+ RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centred = false )
88
88
89
89
Optimizer using the
90
90
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
91
91
algorithm. Often a good choice for recurrent networks. Parameters other than learning rate
92
92
generally don't need tuning.
93
93
94
+ [Centred RMSProp](http://arxiv.org/abs/1308.08500) is a variant which normalises
95
+ gradients by an estimate their variance, instead of their second moment.
96
+
94
97
# Parameters
95
98
- Learning rate (`η`): Amount by which gradients are discounted before updating
96
99
the weights.
97
100
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
98
101
prominent direction, in effect dampening oscillations.
99
102
- Machine epsilon (`ϵ`): Constant to prevent division by zero
100
103
(no need to change default)
104
+ - Keyword `centred` (or `centered`): Indicates whether to use centred variant
105
+ of the algorithm.
101
106
"""
102
107
struct RMSProp{T}
103
108
eta:: T
104
109
rho:: T
105
110
epsilon:: T
111
+ centred:: Bool
106
112
end
107
- RMSProp (η = 1f-3 , ρ = 9f-1 , ϵ = eps (typeof (η))) = RMSProp {typeof(η)} (η, ρ, ϵ)
113
+ RMSProp (η = 1f-3 , ρ = 9f-1 , ϵ = eps (typeof (η)); centred:: Bool = false , centered:: Bool = false ) =
114
+ RMSProp {typeof(η)} (η, ρ, ϵ, centred | centered)
108
115
109
- init (o:: RMSProp , x:: AbstractArray ) = zero (x)
116
+ init (o:: RMSProp , x:: AbstractArray ) = ( zero (x), o . centred ? zero (x) : false )
110
117
111
118
function apply! (o:: RMSProp , state, x, dx)
112
- η, ρ, ϵ, acc = o. eta, o. rho, o. epsilon, state
119
+ η, ρ, ϵ = o. eta, o. rho, o. epsilon
120
+ quad, lin = state
113
121
114
- @. . acc = ρ * acc + (1 - ρ) * abs2 (dx)
115
- dx′ = @lazy dx * (η / (sqrt (acc) + ϵ))
122
+ @. . quad = ρ * quad + (1 - ρ) * abs2 (dx)
123
+ if o. centred
124
+ @. . lin = ρ * lin + (1 - ρ) * dx
125
+ end
126
+ dx′ = @lazy dx * η / (sqrt (quad - abs2 (lin)) + ϵ)
116
127
117
- return acc, dx′
128
+ return (quad, lin), dx′
129
+ end
130
+
131
+ function Base. show (io:: IO , o:: RMSProp )
132
+ show (io, typeof (o))
133
+ print (io, " (" )
134
+ join (io, [o. eta, o. rho, o. epsilon], " , " )
135
+ print (io, " ; centred = " , o. centred, " )" )
118
136
end
119
137
120
138
"""
0 commit comments