@@ -636,10 +636,12 @@ function _norm(dx::Broadcast.Broadcasted, p::Real)
636
636
end
637
637
638
638
"""
639
- OptimiserChain(opts...)
639
+ OptimiserChain(o1, o2, o34...)
640
+ o1 => o2 => o3
640
641
641
- Compose a sequence of optimisers so that each `opt` in `opts `
642
+ Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)`
642
643
updates the gradient, in the order specified.
644
+ A chain may also be written with `Pair` syntax, e.g. `o1 => o2 => o3`, where every element is an `AbstractRule`.
643
645
644
646
With an empty sequence, `OptimiserChain()` is the identity,
645
647
so `update!` will subtract the full gradient from the parameters.
@@ -648,12 +650,13 @@ This is equivalent to `Descent(1)`.
648
650
# Example
649
651
650
652
```jldoctest
651
- julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
653
+ julia> o = ClipGrad(1.0) => Descent(0.1)
654
+ OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1))
652
655
653
656
julia> m = (zeros(3),);
654
657
655
658
julia> s = Optimisers.setup(o, m)
656
- (Leaf(OptimiserChain( ClipGrad(1.0), Descent(0.1) ), (nothing, nothing)),)
659
+ (Leaf(ClipGrad(1.0) => Descent(0.1), (nothing, nothing)),)
657
660
658
661
julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
659
662
([-0.03, -0.1, -0.1],)
@@ -664,6 +667,9 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
664
667
end
665
668
# Varargs convenience constructor: gathers the positional arguments into one
# tuple and forwards that tuple as the single argument.
function OptimiserChain(opts...)
    return OptimiserChain(opts)
end
666
669
670
# `a => b` on two rules builds a two-element chain rather than a `Pair`.
function Base.Pair(a::AbstractRule, b::AbstractRule)
    return OptimiserChain(a, b)
end

# `a => b => c` groups to the right, so the right operand is already a chain;
# splat its rules after `a` so the result is a single flat chain.
function Base.Pair(a::AbstractRule, bc::OptimiserChain)
    return OptimiserChain(a, bc.opts...)
end
672
+
667
673
@functor OptimiserChain
668
674
669
675
# State for a chain: one `init` result per wrapped rule, in the same order as
# `o.opts`, all initialised against the same array `x`.
function init(o::OptimiserChain, x::AbstractArray)
    return map(o.opts) do rule
        init(rule, x)
    end
end
@@ -679,7 +685,14 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...)
679
685
end
680
686
end
681
687
682
- function Base. show (io:: IO , c:: OptimiserChain )
688
# Compact display of a chain.
#
# With two or more rules the chain is written in `Pair` form, `r1 => r2 => ...`.
# With zero or one rule there is nothing to join with `=>`, so printing is
# delegated to the `MIME"text/plain"` method for the same type.
function Base.show(io::IO, c::OptimiserChain)  # compact show
    if length(c.opts) > 1
        join(io, c.opts, " => ")
    else
        # The MIME type must be spelled exactly `MIME"text/plain"` (string macro,
        # no inner whitespace) to reach the three-argument `show` method.
        show(io, MIME"text/plain"(), c)
    end
end
695
+ function Base. show (io:: IO , :: MIME"text/plain" , c:: OptimiserChain )
683
696
print (io, " OptimiserChain(" )
684
697
join (io, c. opts, " , " )
685
698
print (io, " )" )
0 commit comments