
"""
    Descent(η = 1f-1)
-     Descent(; eta)
+     Descent(; [eta])

Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.
@@ -20,12 +20,13 @@ For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.
struct Descent{T} <: AbstractRule
  eta::T
end
+
Descent(; eta = 1f-1) = Descent(eta)

init(o::Descent, x::AbstractArray) = nothing

function apply!(o::Descent, state, x, dx)
-   η = convert(float(eltype(x)), o.eta)
+   η = ofeltype(x, o.eta)

  return state, @lazy dx * η  # @lazy creates a Broadcasted, will later fuse with x .= x .- dx
end
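
Both constructor forms advertised in the Descent docstring are meant to build the same rule. A minimal usage sketch, assuming the keyword constructor `Descent(; eta)` shown above and the usual `setup`/`update` API; the `model`, `grad`, and values here are illustrative only:

```julia
using Optimisers

model = (W = rand(3, 3),)    # toy "model": any nested structure of arrays
rule  = Descent(0.01)        # positional form; Descent(; eta = 0.01) should be equivalent
state = Optimisers.setup(rule, model)
grad  = (W = ones(3, 3),)    # stand-in gradient with the same structure as `model`
state, model = Optimisers.update(state, model, grad)  # effectively W .= W .- 0.01 .* grad.W
```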
@@ -64,6 +65,8 @@

"""
    Nesterov(η = 0.001, ρ = 0.9)
+     Nesterov(; [eta, rho])
+

Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`.

@@ -153,27 +156,26 @@ end

"""
    Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))
+     Rprop(; [eta, ell, gamma])

Optimizer using the
[Rprop](https://ieeexplore.ieee.org/document/298623) algorithm. A full-batch
learning algorithm that depends only on the sign of the gradient.

# Parameters
- - Learning rate (`η`): Amount by which gradients are discounted before updating
+ - Learning rate (`η == eta`): Amount by which gradients are discounted before updating
  the weights.

- - Scaling factors (`ℓ::Tuple`): Multiplicative increase and decrease factors.
+ - Scaling factors (`ℓ::Tuple == ell`): Multiplicative increase and decrease factors.

- - Step sizes (`Γ::Tuple`): Mminimal and maximal allowed step sizes.
+ - Step sizes (`Γ::Tuple == gamma`): Minimal and maximal allowed step sizes.
"""
- struct Rprop{T} <: AbstractRule
-   eta::T
-   ell::Tuple{T,T}
-   gamma::Tuple{T,T}
+ @def struct Rprop <: AbstractRule
+   eta = 1f-3
+   ell = (5f-1, 1.2f0)
+   gamma = (1f-6, 50f0)
end

- Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0)) = Rprop{typeof(η)}(η, ℓ, Γ)
-
init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))

function apply!(o::Rprop, state, x::AbstractArray{T}, dx) where T
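
The switch from a hand-written struct plus outer constructor to `@def` keeps the same defaults. Assuming `@def` behaves here as it does for the other rules in this file (the listed defaults become both positional and keyword arguments), these calls would be interchangeable — a sketch, not taken from the diff:

```julia
Rprop()                                    # all defaults
Rprop(1f-3, (5f-1, 1.2f0), (1f-6, 50f0))   # positional, matching the removed constructor
Rprop(; eta = 1f-3, gamma = (1f-6, 50f0))  # keyword form advertised in the docstring
```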
@@ -193,15 +195,16 @@ end

"""
    Adam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
+     Adam(; [eta, beta, epsilon])

[Adam](https://arxiv.org/abs/1412.6980) optimiser.

# Parameters
- - Learning rate (`η`): Amount by which gradients are discounted before updating
+ - Learning rate (`η == eta`): Amount by which gradients are discounted before updating
  the weights.
- - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+ - Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
  second (β2) momentum estimate.
- - Machine epsilon (`ϵ`): Constant to prevent division by zero
+ - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
  (no need to change default)
"""
@def struct Adam <: AbstractRule
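
The `η == eta` notation added throughout these docstrings says that the Unicode symbol used in the formulas is stored under the ASCII field name, which is also the keyword argument. A small sketch, assuming the keyword constructor shown above (values are arbitrary):

```julia
rule = Adam(; eta = 3e-4, beta = (0.9, 0.999))
rule.eta                     # 0.0003 — the ASCII name is the struct field

# A later tweak could go through the existing adjust!/adjust API, e.g. given a
# `state` returned by Optimisers.setup (hypothetical here):
# Optimisers.adjust!(state; eta = 1e-4)
```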
@@ -225,12 +228,13 @@ end

"""
    Lion(η = 0.001, β = (0.9, 0.999))
+     Lion(; [eta, beta])

[Lion](https://arxiv.org/abs/2302.06675) optimiser.

# Parameters
- - Learning rate (`η`): Magnitude by which gradients are updating the weights.
- - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+ - Learning rate (`η == eta`): Magnitude by which gradients update the weights.
+ - Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
  second (β2) momentum estimate.
"""
@def struct Lion <: AbstractRule
@@ -254,15 +258,16 @@ end

"""
    RAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
+     RAdam(; [eta, beta, epsilon])

[Rectified Adam](https://arxiv.org/abs/1908.03265) optimizer.

# Parameters
- - Learning rate (`η`): Amount by which gradients are discounted before updating
+ - Learning rate (`η == eta`): Amount by which gradients are discounted before updating
  the weights.
- - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+ - Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
  second (β2) momentum estimate.
- - Machine epsilon (`ϵ`): Constant to prevent division by zero
+ - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
  (no need to change default)
"""
@def struct RAdam <: AbstractRule
@@ -294,15 +299,16 @@ end

"""
    AdaMax(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
+     AdaMax(; [eta, beta, epsilon])

[AdaMax](https://arxiv.org/abs/1412.6980) is a variant of Adam based on the ∞-norm.

# Parameters
- - Learning rate (`η`): Amount by which gradients are discounted before updating
+ - Learning rate (`η == eta`): Amount by which gradients are discounted before updating
  the weights.
- - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+ - Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
  second (β2) momentum estimate.
- - Machine epsilon (`ϵ`): Constant to prevent division by zero
+ - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
  (no need to change default)
"""
@def struct AdaMax <: AbstractRule
@@ -326,16 +332,17 @@ end

"""
    OAdam(η = 0.001, β = (0.5, 0.9), ϵ = 1e-8)
+     OAdam(; [eta, beta, epsilon])

[OAdam](https://arxiv.org/abs/1711.00141) (Optimistic Adam)
is a variant of Adam adding an "optimistic" term suitable for adversarial training.

# Parameters
- - Learning rate (`η`): Amount by which gradients are discounted before updating
+ - Learning rate (`η == eta`): Amount by which gradients are discounted before updating
  the weights.
- - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+ - Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
  second (β2) momentum estimate.
- - Machine epsilon (`ϵ`): Constant to prevent division by zero
+ - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
  (no need to change default)
"""
@def struct OAdam <: AbstractRule
@@ -361,15 +368,16 @@ end

"""
    AdaGrad(η = 0.1, ϵ = 1e-8)
+     AdaGrad(; [eta, epsilon])

[AdaGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
parameter specific learning rates based on how frequently it is updated.
Parameters don't need tuning.

# Parameters
- - Learning rate (`η`): Amount by which gradients are discounted before updating
+ - Learning rate (`η == eta`): Amount by which gradients are discounted before updating
  the weights.
- - Machine epsilon (`ϵ`): Constant to prevent division by zero
+ - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
  (no need to change default)
"""
@def struct AdaGrad <: AbstractRule
@@ -391,14 +399,15 @@ end

"""
    AdaDelta(ρ = 0.9, ϵ = 1e-8)
+     AdaDelta(; [rho, epsilon])

[AdaDelta](https://arxiv.org/abs/1212.5701) is a version of AdaGrad adapting its learning
rate based on a window of past gradient updates.
Parameters don't need tuning.

# Parameters
- - Rho (`ρ`): Factor by which the gradient is decayed at each time step.
- - Machine epsilon (`ϵ`): Constant to prevent division by zero
+ - Rho (`ρ == rho`): Factor by which the gradient is decayed at each time step.
+ - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
  (no need to change default)
"""
@def struct AdaDelta <: AbstractRule
@@ -422,16 +431,17 @@ end

"""
    AMSGrad(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
+     AMSGrad(; [eta, beta, epsilon])

The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the Adam
optimiser. Parameters don't need tuning.

# Parameters
- - Learning rate (`η`): Amount by which gradients are discounted before updating
+ - Learning rate (`η == eta`): Amount by which gradients are discounted before updating
  the weights.
- - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+ - Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
  second (β2) momentum estimate.
- - Machine epsilon (`ϵ`): Constant to prevent division by zero
+ - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
  (no need to change default)
"""
@def struct AMSGrad <: AbstractRule
@@ -457,16 +467,17 @@ end

"""
    NAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
+     NAdam(; [eta, beta, epsilon])

[NAdam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) is a Nesterov variant of Adam.
Parameters don't need tuning.

# Parameters
- - Learning rate (`η`): Amount by which gradients are discounted before updating
+ - Learning rate (`η == eta`): Amount by which gradients are discounted before updating
  the weights.
- - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+ - Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
  second (β2) momentum estimate.
- - Machine epsilon (`ϵ`): Constant to prevent division by zero
+ - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
  (no need to change default)
"""
@def struct NAdam <: AbstractRule
@@ -515,16 +526,17 @@ AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =

"""
    AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
+     AdaBelief(; [eta, beta, epsilon])

The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known
Adam optimiser.

# Parameters
- - Learning rate (`η`): Amount by which gradients are discounted before updating
+ - Learning rate (`η == eta`): Amount by which gradients are discounted before updating
  the weights.
- - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+ - Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
  second (β2) momentum estimate.
- - Machine epsilon (`ϵ::Float32`): Constant to prevent division by zero
+ - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
  (no need to change default)
"""
@def struct AdaBelief <: AbstractRule
@@ -548,6 +560,7 @@

"""
    WeightDecay(λ = 5e-4)
+     WeightDecay(; [lambda])

Implements ``L_2`` regularisation, also known as ridge regression,
when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
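
The composition the WeightDecay docstring recommends might look like this in practice; a brief sketch, assuming the keyword spelling `WeightDecay(; lambda)` added above:

```julia
# L2 regularisation applied before Adam's update, as the docstring suggests:
decayed_adam = OptimiserChain(WeightDecay(5e-4), Adam())

# the same chain using the keyword forms this diff documents:
decayed_adam = OptimiserChain(WeightDecay(; lambda = 5e-4), Adam(; eta = 1e-3))
```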
@@ -585,6 +598,7 @@ function adjust(r::WeightDecay; gamma = nothing, kw...)

"""
    SignDecay(λ = 1e-3)
+     SignDecay(; [lambda])

Implements ``L_1`` regularisation, also known as LASSO regression,
when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
@@ -615,6 +629,7 @@

"""
    ClipGrad(δ = 10)
+     ClipGrad(; [delta])

Restricts every gradient component to obey `-δ ≤ dx[i] ≤ δ`.

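ClipGrad's restriction `-δ ≤ dx[i] ≤ δ` amounts to clamping each gradient component, so like the decay rules it is normally placed ahead of the main optimiser in a chain. A short sketch, assuming the `ClipGrad(; delta)` keyword spelling added above:

```julia
# Each gradient entry is limited to [-1, 1] before the Descent step sees it,
# i.e. the chained rule works on roughly clamp.(dx, -1, 1):
clipped = OptimiserChain(ClipGrad(; delta = 1.0), Descent(; eta = 0.1))
```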