@@ -138,32 +138,47 @@ end
138
138
139
139
140
140
"""
141
- RProp
141
+ Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))
142
+
143
+ Optimizer using the
144
+ [Rprop](https://ieeexplore.ieee.org/document/298623) algorithm. An full-batch
145
+ learning algorithm that depends only on the sign of the gradient.
146
+
147
+ # Parameters
148
+ - Learning rate (`η`): Amount by which gradients are discounted before updating
149
+ the weights.
150
+
151
+ - Scaling factors (`ℓ::Tuple`): Multiplicative increase and decrease factors.
152
+
153
+ - Step sizes (`Γ::Tuple`): Mminimal and maximal allowed step sizes.
142
154
"""
143
- struct RProp {T} <: AbstractRule
144
- eta:: T
145
- ell:: Tuple{T,2 }
146
- gamma:: Tuple{T,2 }
155
+ struct Rprop {T} <: AbstractRule
156
+ eta:: T
157
+ ell:: Tuple{T,T }
158
+ gamma:: Tuple{T,T }
147
159
end
148
160
149
- RProp (η = 1f-3 , ℓ = (5f-1 , 1.2f0 ), Γ = (1f-6 , 50f0 )) = RProp {typeof(η)} (η, ℓ, Γ)
161
+ Rprop (η = 1f-3 , ℓ = (5f-1 , 1.2f0 ), Γ = (1f-6 , 50f0 )) = Rprop {typeof(η)} (η, ℓ, Γ)
150
162
151
- init (o:: RProp , x:: AbstractArray ) = (zero (x), onevalue (o. eta, x))
163
+ init (o:: Rprop , x:: AbstractArray ) = (zero (x), onevalue (o. eta, x))
152
164
153
- function apply! (o:: RProp , state, x, dx)
154
- ℓ, Γ = o. ell, o. gamma
155
- g₀, η₀ = state
165
+ function apply! (o:: Rprop , state, x, dx)
166
+ ℓ, Γ = o. ell, o. gamma
167
+ g₀, η₀ = state
156
168
157
- @. . ind = g₀ * dx
169
+ signs = g₀ .* dx
170
+ signs[signs .> 0 ] .= ℓ[2 ]
171
+ signs[signs .< 0 ] .= ℓ[1 ]
172
+ signs[signs .== 0 ] .= one (eltype (signs))
158
173
159
- g₁ = map (i -> ind[i] < 0f0 ? zero (g₀[i]) : g₀[i], CartesianIndices (g₀))
160
- η₁ = map (i -> ind[i] > zero (ind[i]) ? min (η₀[i] * ℓ[2 ], Γ[2 ]) :
161
- ind[i] < zero (ind[i]) ? max (η₀[i] * ℓ[1 ], Γ[1 ]) : η₀[i],
162
- CartesianIndices (η₀))
174
+ η₁ = clamp .(η₀ .* signs, Γ[1 ], Γ[2 ])
163
175
164
- dx' = @lazy dx * sign (g₁)
165
-
166
- return (g₁, η₁), dx'
176
+ g₁ = copy (dx)
177
+ g₁[signs .== ℓ[1 ]] .= zero (eltype (g₁))
178
+
179
+ dx′ = @lazy η₁ * sign (g₁)
180
+
181
+ return (g₁, η₁), dx′
167
182
end
168
183
169
184
0 commit comments