@@ -621,23 +621,26 @@ function apply!(o::ClipNorm, state, x, dx)
621
621
end
622
622
623
623
"""
624
- OptimiserChain(opts...)
624
+ OptimiserChain(o1, o2, o34...)
625
+ o1 => o2 => o3
625
626
626
- Compose a sequence of optimisers so that each `opt` in `opts `
627
+ Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)`
627
628
updates the gradient, in the order specified.
629
+ May also be written with `Pair` syntax, e.g. `o1 => o2 => o3`, given two or more `AbstractRule`s.
628
630
629
631
With an empty sequence, `OptimiserChain()` is the identity,
630
632
so `update!` will subtract the full gradient from the parameters.
631
633
This is equivalent to `Descent(1)`.
632
634
633
635
# Example
634
636
```jldoctest
635
- julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
637
+ julia> o = ClipGrad(1.0) => Descent(0.1)
638
+ OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1))
636
639
637
640
julia> m = (zeros(3),);
638
641
639
642
julia> s = Optimisers.setup(o, m)
640
- (Leaf(OptimiserChain( ClipGrad{Float64}(1.0), Descent{Float64}(0.1) ), (nothing, nothing)),)
643
+ (Leaf(ClipGrad{Float64}(1.0) => Descent{Float64}(0.1), (nothing, nothing)),)
641
644
642
645
julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
643
646
([-0.03, -0.1, -0.1],)
@@ -648,6 +651,9 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
648
651
end
649
652
# Varargs convenience constructor: gathers the given rules into one tuple and
# forwards that tuple to the single-argument constructor.
function OptimiserChain(opts...)
    return OptimiserChain(opts)
end
650
653
654
# `Pair` syntax for chaining rules: `a => b` builds `OptimiserChain(a, b)`
# instead of an ordinary `Pair`.
function Base.Pair(a::AbstractRule, b::AbstractRule)
    return OptimiserChain(a, b)
end

# `=>` is right-associative, so `a => b => c` arrives here as `a => (b => c)`.
# Splat the rules of the already-built chain after `a` so the result is one
# flat chain rather than a chain nested inside a chain.
function Base.Pair(a::AbstractRule, bc::OptimiserChain)
    return OptimiserChain(a, bc.opts...)
end
656
+
651
657
@functor OptimiserChain
652
658
653
659
# Build the chain's state by calling `init` on every wrapped rule with the same
# array `x`; the results are returned in the same order as `o.opts`.
function init(o::OptimiserChain, x::AbstractArray)
    return map(o.opts) do opt
        init(opt, x)
    end
end
@@ -659,7 +665,14 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...)
659
665
end
660
666
end
661
667
662
- function Base. show (io:: IO , c:: OptimiserChain )
668
# Compact (2-argument) `show` for an `OptimiserChain`.
#
# A chain holding two or more rules is printed as its rules joined by " => ",
# matching the `Pair` syntax that can be used to construct it. A chain with
# zero or one rule has nothing to join, so it falls back to the `text/plain`
# method.
function Base.show(io::IO, c::OptimiserChain)  # compact show
    if length(c.opts) > 1
        join(io, c.opts, " => ")
    else
        # The MIME instance must be spelled exactly `MIME"text/plain"()` so that
        # this call dispatches to the `::MIME"text/plain"` method.
        show(io, MIME"text/plain"(), c)
    end
end
675
+ function Base. show (io:: IO , :: MIME"text/plain" , c:: OptimiserChain )
663
676
print (io, " OptimiserChain(" )
664
677
join (io, c. opts, " , " )
665
678
print (io, " )" )
0 commit comments