@@ -114,16 +114,16 @@ gradients by an estimate of their variance, instead of their second moment.
 - Keyword `centred` (or `centered`): Indicates whether to use the centred variant
   of the algorithm.
 """
-struct RMSProp <: AbstractRule
-  eta::Float64
-  rho::Float64
-  epsilon::Float64
+struct RMSProp{Teta,Trho,Teps} <: AbstractRule
+  eta::Teta
+  rho::Trho
+  epsilon::Teps
   centred::Bool
 end
 
 function RMSProp(η, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
-  RMSProp(η, ρ, ϵ, centred | centered)
+  return RMSProp(float(η), float(ρ), float(ϵ), centred | centered)
 end
 RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, epsilon; kw...)
 
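With the parametric struct above, RMSProp's hyperparameters keep the precision they are given instead of being widened to `Float64`. A minimal illustrative sketch of the effect (assuming the surrounding package is Optimisers.jl; the variable names below are only for demonstration):

    using Optimisers

    rule32 = RMSProp(1f-3, 0.9f0, 1f-8)   # all-Float32 hyperparameters
    rule32.eta isa Float32                # true: float(1f-3) is still a Float32
    rule64 = RMSProp(1e-3)                # Float64 input keeps giving Float64 fields, as before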
 
 
 """
-    Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))
+    Rprop(η = 1e-3, ℓ = (0.5, 1.2), Γ = (1e-6, 50.0))
     Rprop(; [eta, ell, gamma])
 
 Optimizer using the
@@ -171,9 +171,9 @@ learning algorithm that depends only on the sign of the gradient.
 - Step sizes (`Γ::Tuple == gamma`): Minimal and maximal allowed step sizes.
 """
 @def struct Rprop <: AbstractRule
-  eta = 1f-3
-  ell = (5f-1, 1.2f0)
-  gamma = (1f-6, 50f0)
+  eta = 1e-3
+  ell = (0.5, 1.2)
+  gamma = (1e-6, 50.0)
 end
 
 init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
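Here the Rprop defaults move from `Float32` literals to `Float64` ones. A hedged sketch of what that means for callers, assuming `@def` still generates the positional and keyword constructors shown in the docstring above (whether explicit `Float32` values are preserved depends on how `@def` types the fields, which is outside this hunk):

    Rprop().eta == 1e-3                        # defaults are now Float64 literals
    Rprop(1f-3, (5f-1, 1.2f0), (1f-6, 50f0))   # the old Float32 values can still be passed explicitly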
@@ -528,17 +528,17 @@ Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`]
 The previous rule, which is closer to the original paper, can be obtained by setting `AdamW(..., couple=false)`.
 See [this issue](https://github.com/FluxML/Flux.jl/issues/2433) for more details.
 """
-struct AdamW <: AbstractRule
-  eta::Float64
-  beta::Tuple{Float64, Float64}
-  lambda::Float64
-  epsilon::Float64
+struct AdamW{Teta,Tbeta<:Tuple,Tlambda,Teps} <: AbstractRule
+  eta::Teta
+  beta::Tbeta
+  lambda::Tlambda
+  epsilon::Teps
   couple::Bool
 end
 
 function AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8; couple::Bool = true)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
-  AdamW(η, β, λ, ϵ, couple)
+  return AdamW(float(η), β, float(λ), float(ϵ), couple)
 end
 
 AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0.0, epsilon = 1e-8, kw...) =
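Note that `η`, `λ`, and `ϵ` are passed through `float`, while the `β` tuple is stored as given. A rough, illustrative sketch of the resulting field types under the definition above:

    rule = AdamW(1f-3, (0.9f0, 0.999f0), 1f-4)
    typeof(rule)   # AdamW{Float32, Tuple{Float32, Float32}, Float32, Float64}, since ϵ keeps its Float64 default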
@@ -704,12 +704,12 @@ Typically composed with other rules using [`OptimiserChain`](@ref).
 
 See also [`ClipGrad`](@ref).
 """
-struct ClipNorm <: AbstractRule
-  omega::Float64
-  p::Float64
+struct ClipNorm{To,Tp} <: AbstractRule
+  omega::To
+  p::Tp
   throw::Bool
 end
-ClipNorm(ω = 10, p = 2; throw::Bool = true) = ClipNorm(ω, p, throw)
+ClipNorm(ω = 10, p = 2; throw::Bool = true) = ClipNorm(float(ω), float(p), throw)
 
 init(o::ClipNorm, x::AbstractArray) = nothing
 
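As with the other rules, `ClipNorm` now stores its threshold and norm order at the precision of the arguments. A small illustrative sketch, assuming the constructor above:

    ClipNorm(10f0)         # omega::Float32, while p defaults to 2, which float turns into 2.0::Float64
    ClipNorm(10f0, 2f0)    # both omega and p stored as Float32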