Skip to content

Commit 0326f25

Browse files
committed
doc
1 parent 8b82128 commit 0326f25

File tree

1 file changed

+33
-18
lines changed

1 file changed

+33
-18
lines changed

src/rules.jl

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -138,32 +138,47 @@ end
138138

139139

140140
"""
141-
RProp
141+
Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))
142+
143+
Optimizer using the
144+
[Rprop](https://ieeexplore.ieee.org/document/298623) algorithm. An full-batch
145+
learning algorithm that depends only on the sign of the gradient.
146+
147+
# Parameters
148+
- Learning rate (`η`): Amount by which gradients are discounted before updating
149+
the weights.
150+
151+
- Scaling factors (`ℓ::Tuple`): Multiplicative increase and decrease factors.
152+
153+
- Step sizes (`Γ::Tuple`): Mminimal and maximal allowed step sizes.
142154
"""
143-
struct RProp{T} <: AbstractRule
144-
eta::T
145-
ell::Tuple{T,2}
146-
gamma::Tuple{T,2}
155+
struct Rprop{T} <: AbstractRule
156+
eta::T
157+
ell::Tuple{T,T}
158+
gamma::Tuple{T,T}
147159
end
148160

149-
RProp= 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0)) = RProp{typeof(η)}(η, ℓ, Γ)
161+
Rprop= 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0)) = Rprop{typeof(η)}(η, ℓ, Γ)
150162

151-
init(o::RProp, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
163+
init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
152164

153-
function apply!(o::RProp, state, x, dx)
154-
ℓ, Γ = o.ell, o.gamma
155-
g₀, η₀ = state
165+
function apply!(o::Rprop, state, x, dx)
166+
ℓ, Γ = o.ell, o.gamma
167+
g₀, η₀ = state
156168

157-
@.. ind = g₀ * dx
169+
signs = g₀ .* dx
170+
signs[signs .> 0] .= ℓ[2]
171+
signs[signs .< 0] .= ℓ[1]
172+
signs[signs .== 0] .= one(eltype(signs))
158173

159-
g₁ = map(i -> ind[i] < 0f0 ? zero(g₀[i]) : g₀[i], CartesianIndices(g₀))
160-
η₁ = map(i -> ind[i] > zero(ind[i]) ? min(η₀[i] * ℓ[2], Γ[2]) :
161-
ind[i] < zero(ind[i]) ? max(η₀[i] * ℓ[1], Γ[1]) : η₀[i],
162-
CartesianIndices(η₀))
174+
η₁ = clamp.(η₀ .* signs, Γ[1], Γ[2])
163175

164-
dx' = @lazy dx * sign(g₁)
165-
166-
return (g₁, η₁), dx'
176+
g₁ = copy(dx)
177+
g₁[signs .== ℓ[1]] .= zero(eltype(g₁))
178+
179+
dx′ = @lazy η₁ * sign(g₁)
180+
181+
return (g₁, η₁), dx′
167182
end
168183

169184

0 commit comments

Comments
 (0)