
Commit 2a1b2ed

make docstrings consistent (#187)
* fix docstrings
* address the review comments
1 parent bb71298 commit 2a1b2ed

2 files changed: +54 −38 lines

src/rules.jl

Lines changed: 53 additions & 38 deletions
@@ -8,7 +8,7 @@

 """
     Descent(η = 1f-1)
-    Descent(; eta)
+    Descent(; [eta])

 Classic gradient descent optimiser with learning rate `η`.
 For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.
@@ -20,12 +20,13 @@ For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.
 struct Descent{T} <: AbstractRule
   eta::T
 end
+
 Descent(; eta = 1f-1) = Descent(eta)

 init(o::Descent, x::AbstractArray) = nothing

 function apply!(o::Descent, state, x, dx)
-  η = convert(float(eltype(x)), o.eta)
+  η = ofeltype(x, o.eta)

   return state, @lazy dx * η  # @lazy creates a Broadcasted, will later fuse with x .= x .- dx
 end
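For context, the keyword constructor and the new `ofeltype` helper work together so that a Float64 hyperparameter does not promote Float32 parameters. A minimal sketch (`ofeltype` is an unexported helper, reached here by qualification; the values are illustrative, not part of the diff):

using Optimisers

# Keyword and positional constructors build the same rule:
Descent(eta = 0.01) == Descent(0.01)     # true

# The helper keeps the update in the parameters' eltype:
x = zeros(Float32, 3)
Optimisers.ofeltype(x, 0.01)             # 0.01f0, since float(eltype(x)) == Float32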
@@ -64,6 +65,8 @@ end

 """
     Nesterov(η = 0.001, ρ = 0.9)
+    Nesterov(; [eta, rho])
+

 Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`.
@@ -153,27 +156,26 @@ end

 """
     Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))
+    Rprop(; [eta, ell, gamma])

 Optimizer using the
 [Rprop](https://ieeexplore.ieee.org/document/298623) algorithm. A full-batch
 learning algorithm that depends only on the sign of the gradient.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.

-- Scaling factors (`ℓ::Tuple`): Multiplicative increase and decrease factors.
+- Scaling factors (`ℓ::Tuple == ell`): Multiplicative increase and decrease factors.

-- Step sizes (`Γ::Tuple`): Minimal and maximal allowed step sizes.
+- Step sizes (`Γ::Tuple == gamma`): Minimal and maximal allowed step sizes.
 """
-struct Rprop{T} <: AbstractRule
-  eta::T
-  ell::Tuple{T,T}
-  gamma::Tuple{T,T}
+@def struct Rprop <: AbstractRule
+  eta = 1f-3
+  ell = (5f-1, 1.2f0)
+  gamma = (1f-6, 50f0)
 end

-Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0)) = Rprop{typeof(η)}(η, ℓ, Γ)
-
 init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))

 function apply!(o::Rprop, state, x::AbstractArray{T}, dx) where T
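The hand-written struct and outer constructor are replaced by the package's `@def` macro with field defaults. A hedged sketch of the constructors this is assumed to provide, matching the signatures in the updated docstring (the macro's exact output is not shown in this diff; the values are illustrative):

Rprop()                                     # all defaults: eta = 1f-3, ell = (5f-1, 1.2f0), gamma = (1f-6, 50f0)
Rprop(1f-2)                                 # positional, as in the removed outer constructor
Rprop(eta = 1f-2, gamma = (1f-6, 10f0))     # keyword form documented as Rprop(; [eta, ell, gamma])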
@@ -193,15 +195,16 @@ end

 """
     Adam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
+    Adam(; [eta, beta, epsilon])

 [Adam](https://arxiv.org/abs/1412.6980) optimiser.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
 @def struct Adam <: AbstractRule
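To see the documented keyword names in use, a minimal training-step sketch, assuming the usual `Optimisers.setup`/`update` API (the array and gradient below are placeholders, not part of this commit):

using Optimisers

w = randn(Float32, 3, 3)                        # stand-in for a model's parameters
state = Optimisers.setup(Adam(eta = 1e-3, beta = (0.9, 0.999)), w)
grad = ones(Float32, 3, 3)                      # stand-in for a real gradient
state, w = Optimisers.update(state, w, grad)    # one optimisation step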
@@ -225,12 +228,13 @@ end

 """
     Lion(η = 0.001, β = (0.9, 0.999))
+    Lion(; [eta, beta])

 [Lion](https://arxiv.org/abs/2302.06675) optimiser.

 # Parameters
-- Learning rate (`η`): Magnitude by which gradients are updating the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Learning rate (`η == eta`): Magnitude by which gradients are updating the weights.
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
 """
 @def struct Lion <: AbstractRule
@@ -254,15 +258,16 @@ end

 """
     RAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
+    RAdam(; [eta, beta, epsilon])

 [Rectified Adam](https://arxiv.org/abs/1908.03265) optimizer.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
 @def struct RAdam <: AbstractRule
@@ -294,15 +299,16 @@ end

 """
     AdaMax(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
+    AdaMax(; [eta, beta, epsilon])

 [AdaMax](https://arxiv.org/abs/1412.6980) is a variant of Adam based on the ∞-norm.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
 @def struct AdaMax <: AbstractRule
@@ -326,16 +332,17 @@ end

 """
     OAdam(η = 0.001, β = (0.5, 0.9), ϵ = 1e-8)
+    OAdam(; [eta, beta, epsilon])

 [OAdam](https://arxiv.org/abs/1711.00141) (Optimistic Adam)
 is a variant of Adam adding an "optimistic" term suitable for adversarial training.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
 @def struct OAdam <: AbstractRule
@@ -361,15 +368,16 @@ end

 """
     AdaGrad(η = 0.1, ϵ = 1e-8)
+    AdaGrad(; [eta, epsilon])

 [AdaGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
 parameter specific learning rates based on how frequently it is updated.
 Parameters don't need tuning.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
 @def struct AdaGrad <: AbstractRule
@@ -391,14 +399,15 @@ end

 """
     AdaDelta(ρ = 0.9, ϵ = 1e-8)
+    AdaDelta(; [rho, epsilon])

 [AdaDelta](https://arxiv.org/abs/1212.5701) is a version of AdaGrad adapting its learning
 rate based on a window of past gradient updates.
 Parameters don't need tuning.

 # Parameters
-- Rho (`ρ`): Factor by which the gradient is decayed at each time step.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Rho (`ρ == rho`): Factor by which the gradient is decayed at each time step.
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
 @def struct AdaDelta <: AbstractRule
@@ -422,16 +431,17 @@ end

 """
     AMSGrad(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
+    AMSGrad(; [eta, beta, epsilon])

 The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the Adam
 optimiser. Parameters don't need tuning.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
 @def struct AMSGrad <: AbstractRule
@@ -457,16 +467,17 @@ end

 """
     NAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
+    NAdam(; [eta, beta, epsilon])

 [NAdam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) is a Nesterov variant of Adam.
 Parameters don't need tuning.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
 @def struct NAdam <: AbstractRule
@@ -515,16 +526,17 @@ AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =

 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
+    AdaBelief(; [eta, beta, epsilon])

 The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known
 Adam optimiser.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Machine epsilon (`ϵ::Float32`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
 @def struct AdaBelief <: AbstractRule
@@ -548,6 +560,7 @@ end

 """
     WeightDecay(λ = 5e-4)
+    WeightDecay(; [lambda])

 Implements ``L_2`` regularisation, also known as ridge regression,
 when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
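As the docstring suggests, `WeightDecay` is meant to be composed as the first transformation in an `OptimiserChain`. A short sketch using the new keyword names (the values are illustrative, not from the diff):

rule = OptimiserChain(WeightDecay(lambda = 5e-4), Descent(eta = 0.1))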
@@ -585,6 +598,7 @@ function adjust(r::WeightDecay; gamma = nothing, kw...)

 """
     SignDecay(λ = 1e-3)
+    SignDecay(; [lambda])

 Implements ``L_1`` regularisation, also known as LASSO regression,
 when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
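The hunk context above references `adjust(r::WeightDecay; ...)`; the ASCII names being documented (`eta`, `beta`, `lambda`, ...) are the same keywords `adjust` accepts. A hedged sketch, assuming the standard `Optimisers.setup`/`adjust` behaviour (the parameter array is a placeholder):

w = randn(Float32, 4)
state = Optimisers.setup(OptimiserChain(SignDecay(lambda = 1e-3), Adam()), w)
state = Optimisers.adjust(state; eta = 3e-4)   # retunes Adam's learning rate via its ASCII keyword name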
@@ -615,6 +629,7 @@ end

 """
     ClipGrad(δ = 10)
+    ClipGrad(; [delta])

 Restricts every gradient component to obey `-δ ≤ dx[i] ≤ δ`.

src/utils.jl

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@ foreachvalue(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
   f(v, (get(y, k, nothing) for y in ys)...)
 end

+ofeltype(x, y) = convert(float(eltype(x)), y)
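A quick check of what the new helper does (illustrative REPL lines, not part of the commit):

x = zeros(Float32, 2)
ofeltype(x, 0.1)             # 0.1f0  -- converted to float(eltype(x)) == Float32
ofeltype(zeros(Int, 2), 2)   # 2.0    -- integer eltypes are promoted by float()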
