diff --git a/src/rules.jl b/src/rules.jl
index ecc58609..1cda82d9 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -636,10 +636,13 @@ function _norm(dx::Broadcast.Broadcasted, p::Real)
 end
 
 """
-    OptimiserChain(opts...)
+    OptimiserChain(o1, o2, o34...)
+    o1 >> o2 >> o3
 
-Compose a sequence of optimisers so that each `opt` in `opts`
+Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)`
 updates the gradient, in the order specified.
+May be entered using the `>>` operator with several `AbstractRule`s.
+Chains on either side of `>>` are flattened rather than nested.
 
 With an empty sequence, `OptimiserChain()` is the identity,
 so `update!` will subtract the full gradient from the parameters.
@@ -648,12 +651,13 @@ This is equivalent to `Descent(1)`.
 # Example
 
 ```jldoctest
-julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
+julia> o = ClipGrad(1.0) >> Descent(0.1)
+OptimiserChain(ClipGrad(1.0), Descent(0.1))
 
 julia> m = (zeros(3),);
 
 julia> s = Optimisers.setup(o, m)
-(Leaf(OptimiserChain(ClipGrad(1.0), Descent(0.1)), (nothing, nothing)),)
+(Leaf(ClipGrad(1.0) >> Descent(0.1), (nothing, nothing)),)
 
 julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
 ([-0.03, -0.1, -0.1],)
@@ -664,6 +668,13 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
 end
 OptimiserChain(opts...) = OptimiserChain(opts)
 
+@doc @doc(OptimiserChain)
+Base.:(>>)(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b)
+Base.:(>>)(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...)
+Base.:(>>)(ab::OptimiserChain, c::AbstractRule) = OptimiserChain(ab.opts..., c)
+# chain >> chain matches both methods above equally well; this method resolves the ambiguity
+Base.:(>>)(ab::OptimiserChain, cd::OptimiserChain) = OptimiserChain(ab.opts..., cd.opts...)
+
 @functor OptimiserChain
 
 init(o::OptimiserChain, x::AbstractArray) = map(opt -> init(opt, x), o.opts)
@@ -679,7 +690,14 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...)
   end
 end
 
-function Base.show(io::IO, c::OptimiserChain)
+function Base.show(io::IO, c::OptimiserChain) # compact show
+  if length(c.opts) > 1
+    join(io, c.opts, " >> ")
+  else
+    show(io, MIME"text/plain"(), c)
+  end
+end
+function Base.show(io::IO, ::MIME"text/plain", c::OptimiserChain)
   print(io, "OptimiserChain(")
   join(io, c.opts, ", ")
   print(io, ")")