Skip to content

Commit 0b2d32b

Browse files
authored
Mark OptimiserChain as @functor and improve inference for apply! (#115)
* mark `OptimiserChain` with `@functor`, and improve type inference for `apply!(o::OptimiserChain, ...)`
* fix doc tests
* fix more doctests

Co-authored-by: Jonathan Doucette <[email protected]>
1 parent 444a6b9 commit 0b2d32b

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

src/adjust.jl

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ of the state `tree`.
6868
To change just the learning rate, provide a number `η::Real`.
6969
7070
# Example
71-
```jldoctest
71+
```jldoctest adjust
7272
julia> m = (vec = rand(Float32, 2), fun = sin);
7373
7474
julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero
@@ -88,18 +88,18 @@ julia> st
8888
To change other parameters, `adjust!` also accepts keyword arguments matching the field
8989
names of the optimisation rule's type.
9090
91-
```
91+
```jldoctest adjust
9292
julia> fieldnames(Adam)
9393
(:eta, :beta, :epsilon)
9494
9595
julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
96-
(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), [nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))]), fun = nothing)
96+
(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
9797
9898
julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1) # delta acts on ClipGrad
99-
(vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), [nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))]), fun = nothing)
99+
(vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
100100
101101
julia> Optimisers.adjust(st; beta = "no such field") # silently ignored!
102-
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing)
102+
(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
103103
```
104104
"""
105105
adjust!(tree, eta::Real) = foreach(st -> adjust!(st, eta), tree)

src/rules.jl

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -607,7 +607,7 @@ julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
607607
julia> m = (zeros(3),);
608608
609609
julia> s = Optimisers.setup(o, m)
610-
(Leaf(OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)), [nothing, nothing]),)
610+
(Leaf(OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)), (nothing, nothing)),)
611611
612612
julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
613613
([-0.03, -0.1, -0.1],)
@@ -618,15 +618,15 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
618618
end
619619
OptimiserChain(opts...) = OptimiserChain(opts)
620620

621-
init(o::OptimiserChain, x::AbstractArray) = [init(opt, x) for opt in o.opts]
621+
@functor OptimiserChain
622+
623+
init(o::OptimiserChain, x::AbstractArray) = map(opt -> init(opt, x), o.opts)
622624

623625
function apply!(o::OptimiserChain, states, x, dx, dxs...)
624-
new_states = similar(states)
625-
for (i, (opt, state)) in enumerate(zip(o.opts, states))
626-
new_states[i], dx = apply!(opt, state, x, dx, dxs...)
626+
foldl(tuple.(o.opts, states); init = ((), dx)) do (states′, dx′), (opt, state)
627+
state′, dx′ = apply!(opt, state, x, dx′, dxs...)
628+
return (states′..., state′), dx′
627629
end
628-
629-
return new_states, dx
630630
end
631631

632632
function Base.show(io::IO, c::OptimiserChain)

0 commit comments

Comments (0)