Skip to content

Commit c73dea7

Browse files
authored
Trivial cases of OptimiserChain (#43)
* trivial cases of OptimiserChain
* and the missing show method
* test a nested case too
* always return an OptimiserChain
* rm all optimisations?
* change to show made on the website and forgotten
* doc comment
1 parent 7c778bd commit c73dea7

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

src/rules.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -516,10 +516,27 @@ end
516516
"""
517517
OptimiserChain(opts...)
518518
519-
Compose a chain (sequence) of optimisers so that each `opt` in `opts`
520-
updates the gradient in the order specified.
519+
Compose a sequence of optimisers so that each `opt` in `opts`
520+
updates the gradient, in the order specified.
521+
522+
With an empty sequence, `OptimiserChain()` is the identity,
523+
so `update!` will subtract the full gradient from the parameters.
524+
This is equivalent to `Descent(1)`.
525+
526+
# Example
527+
```jldoctest
528+
julia> o = OptimiserChain(ClipGrad(1), Descent(0.1));
529+
530+
julia> m = ([0,0,0],);
531+
532+
julia> s = Optimisers.setup(o, m)
533+
(Leaf(OptimiserChain(ClipGrad{Int64}(1), Descent{Float64}(0.1)), [nothing, nothing]),)
534+
535+
julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
536+
([-0.03, -0.1, -0.1],)
537+
```
521538
"""
522-
struct OptimiserChain{O}
539+
struct OptimiserChain{O<:Tuple}
523540
opts::O
524541
end
525542
OptimiserChain(opts...) = OptimiserChain(opts)
@@ -534,3 +551,9 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...)
534551

535552
return new_states, dx
536553
end
554+
555+
function Base.show(io::IO, c::OptimiserChain)
556+
print(io, "OptimiserChain(")
557+
join(io, c.opts, ", ")
558+
print(io, ")")
559+
end

test/runtests.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,18 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
6565
o2 = OptimiserChain(ClipGrad(2), WeightDecay(0.1))
6666
@test Optimisers.update(Optimisers.setup(o2, x), x, dx)[2] ≈ [1-0.1-1, 10-1-2, 100-10-2]
6767

68-
o2r = OptimiserChain(WeightDecay(0.1), ClipGrad(2))
68+
o2n = OptimiserChain(OptimiserChain(ClipGrad(2), WeightDecay(0.1))) # nested
69+
@test Optimisers.update(Optimisers.setup(o2n, x), x, dx)[2] ≈ [1-0.1-1, 10-1-2, 100-10-2]
70+
71+
o2r = OptimiserChain(WeightDecay(0.1), ClipGrad(2)) # reversed
6972
@test Optimisers.update(Optimisers.setup(o2r, x), x, dx)[2] != [1-0.1-1, 10-2, 100-2]
73+
74+
# Trivial cases
75+
o1 = OptimiserChain(Descent(0.1))
76+
@test Optimisers.update(Optimisers.setup(o1, x), x, dx)[2] ≈ [0.9, 9.8, 99.7]
77+
78+
o0 = OptimiserChain()
79+
@test Optimisers.update(Optimisers.setup(o0, x), x, dx)[2] ≈ [1-1, 10-2, 100-3]
7080
end
7181

7282
@testset "trainable subset" begin

0 commit comments

Comments (0)