Skip to content

Commit 99ae33d

Browse files
committed
write => for OptimiserChain
1 parent 1cd1e87 commit 99ae33d

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

src/rules.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -636,10 +636,12 @@ function _norm(dx::Broadcast.Broadcasted, p::Real)
636636
end
637637

638638
"""
639-
OptimiserChain(opts...)
639+
OptimiserChain(o1, o2, o34...)
640+
o1 => o2 => o3
640641
641-
Compose a sequence of optimisers so that each `opt` in `opts`
642+
Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)`
642643
updates the gradient, in the order specified.
644+
May be entered using `Pair` syntax with several `AbstractRule`s.
643645
644646
With an empty sequence, `OptimiserChain()` is the identity,
645647
so `update!` will subtract the full gradient from the parameters.
@@ -648,12 +650,13 @@ This is equivalent to `Descent(1)`.
648650
# Example
649651
650652
```jldoctest
651-
julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
653+
julia> o = ClipGrad(1.0) => Descent(0.1)
654+
OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1))
652655
653656
julia> m = (zeros(3),);
654657
655658
julia> s = Optimisers.setup(o, m)
656-
(Leaf(OptimiserChain(ClipGrad(1.0), Descent(0.1)), (nothing, nothing)),)
659+
(Leaf(ClipGrad(1.0) => Descent(0.1), (nothing, nothing)),)
657660
658661
julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
659662
([-0.03, -0.1, -0.1],)
@@ -664,6 +667,9 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
664667
end
665668
OptimiserChain(opts...) = OptimiserChain(opts)
666669

670+
Base.Pair(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b)
671+
Base.Pair(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...)
672+
667673
@functor OptimiserChain
668674

669675
init(o::OptimiserChain, x::AbstractArray) = map(opt -> init(opt, x), o.opts)
@@ -679,7 +685,14 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...)
679685
end
680686
end
681687

682-
function Base.show(io::IO, c::OptimiserChain)
688+
function Base.show(io::IO, c::OptimiserChain) # compact show
689+
if length(c.opts) > 1
690+
join(io, c.opts, " => ")
691+
else
692+
show(io, MIME"text/plain"(), c)
693+
end
694+
end
695+
function Base.show(io::IO, ::MIME"text/plain", c::OptimiserChain)
683696
print(io, "OptimiserChain(")
684697
join(io, c.opts, ", ")
685698
print(io, ")")

0 commit comments

Comments
 (0)