Skip to content

Commit 0258de8

Browse files
committed
write => for OptimiserChain
1 parent 14949f1 commit 0258de8

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

src/rules.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -621,23 +621,26 @@ function apply!(o::ClipNorm, state, x, dx)
621621
end
622622

623623
"""
624-
OptimiserChain(opts...)
624+
OptimiserChain(o1, o2, o34...)
625+
o1 => o2 => o3
625626
626-
Compose a sequence of optimisers so that each `opt` in `opts`
627+
Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)`
627628
updates the gradient, in the order specified.
629+
May be entered using `Pair` syntax with several `AbstractRule`s.
628630
629631
With an empty sequence, `OptimiserChain()` is the identity,
630632
so `update!` will subtract the full gradient from the parameters.
631633
This is equivalent to `Descent(1)`.
632634
633635
# Example
634636
```jldoctest
635-
julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
637+
julia> o = ClipGrad(1.0) => Descent(0.1)
638+
OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1))
636639
637640
julia> m = (zeros(3),);
638641
639642
julia> s = Optimisers.setup(o, m)
640-
(Leaf(OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)), (nothing, nothing)),)
643+
(Leaf(ClipGrad{Float64}(1.0) => Descent{Float64}(0.1), (nothing, nothing)),)
641644
642645
julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
643646
([-0.03, -0.1, -0.1],)
@@ -648,6 +651,9 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
648651
end
649652
OptimiserChain(opts...) = OptimiserChain(opts)
650653

654+
Base.Pair(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b)
655+
Base.Pair(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...)
656+
651657
@functor OptimiserChain
652658

653659
init(o::OptimiserChain, x::AbstractArray) = map(opt -> init(opt, x), o.opts)
@@ -659,7 +665,14 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...)
659665
end
660666
end
661667

662-
function Base.show(io::IO, c::OptimiserChain)
668+
function Base.show(io::IO, c::OptimiserChain) # compact show
669+
if length(c.opts) > 1
670+
join(io, c.opts, " => ")
671+
else
672+
show(io, MIME"text/plain"(), c)
673+
end
674+
end
675+
function Base.show(io::IO, ::MIME"text/plain", c::OptimiserChain)
663676
print(io, "OptimiserChain(")
664677
join(io, c.opts, ", ")
665678
print(io, ")")

0 commit comments

Comments
 (0)