@@ -637,11 +637,11 @@ end
637
637
638
638
"""
639
639
OptimiserChain(o1, o2, o34...)
640
- o1 => o2 = > o3
640
+ o1 >> o2 > > o3
641
641
642
642
Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)`
643
643
updates the gradient, in the order specified.
644
- May be entered using `Pair` syntax with several `AbstractRule`s.
644
+ May be entered using the `>>` operator with several `AbstractRule`s.
645
645
646
646
With an empty sequence, `OptimiserChain()` is the identity,
647
647
so `update!` will subtract the full gradient from the parameters.
@@ -650,8 +650,8 @@ This is equivalent to `Descent(1)`.
650
650
# Example
651
651
652
652
```jldoctest
653
- julia> o = ClipGrad(1.0) = > Descent(0.1)
654
- OptimiserChain(ClipGrad{Float64} (1.0), Descent{Float64} (0.1))
653
+ julia> o = ClipGrad(1.0) > > Descent(0.1)
654
+ OptimiserChain(ClipGrad(1.0), Descent(0.1))
655
655
656
656
julia> m = (zeros(3),);
657
657
@@ -667,8 +667,10 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
667
667
end
668
668
OptimiserChain (opts... ) = OptimiserChain (opts)
669
669
670
- Base. Pair (a:: AbstractRule , b:: AbstractRule ) = OptimiserChain (a, b)
671
- Base. Pair (a:: AbstractRule , bc:: OptimiserChain ) = OptimiserChain (a, bc. opts... )
670
+ @doc @doc (OptimiserChain)
671
+ Base.:(>> )(a:: AbstractRule , b:: AbstractRule ) = OptimiserChain (a, b)
672
+ Base.:(>> )(a:: AbstractRule , bc:: OptimiserChain ) = OptimiserChain (a, bc. opts... )
673
+ Base.:(>> )(ab:: OptimiserChain , c:: AbstractRule ) = OptimiserChain (ab. opts... , c)
672
674
673
675
@functor OptimiserChain
674
676
687
689
688
690
function Base. show (io:: IO , c:: OptimiserChain ) # compact show
689
691
if length (c. opts) > 1
690
- join (io, c. opts, " = > " )
692
+ join (io, c. opts, " > > " )
691
693
else
692
694
show (io, MIME " text/plain" (), c)
693
695
end
0 commit comments