Skip to content

Commit 4a55975

Browse files
committed
change from => to >>
1 parent 99ae33d commit 4a55975

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/rules.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -637,11 +637,11 @@ end
637637

638638
"""
639639
OptimiserChain(o1, o2, o34...)
640-
o1 => o2 => o3
640+
o1 >> o2 >> o3
641641
642642
Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)`
643643
updates the gradient, in the order specified.
644-
May be entered using `Pair` syntax with several `AbstractRule`s.
644+
May be entered using the `>>` operator with several `AbstractRule`s.
645645
646646
With an empty sequence, `OptimiserChain()` is the identity,
647647
so `update!` will subtract the full gradient from the parameters.
@@ -650,8 +650,8 @@ This is equivalent to `Descent(1)`.
650650
# Example
651651
652652
```jldoctest
653-
julia> o = ClipGrad(1.0) => Descent(0.1)
654-
OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1))
653+
julia> o = ClipGrad(1.0) >> Descent(0.1)
654+
OptimiserChain(ClipGrad(1.0), Descent(0.1))
655655
656656
julia> m = (zeros(3),);
657657
@@ -667,8 +667,10 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
667667
end
668668
OptimiserChain(opts...) = OptimiserChain(opts)
669669

670-
Base.Pair(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b)
671-
Base.Pair(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...)
670+
@doc @doc(OptimiserChain)
671+
Base.:(>>)(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b)
672+
Base.:(>>)(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...)
673+
Base.:(>>)(ab::OptimiserChain, c::AbstractRule) = OptimiserChain(ab.opts..., c)
672674

673675
@functor OptimiserChain
674676

@@ -687,7 +689,7 @@ end
687689

688690
function Base.show(io::IO, c::OptimiserChain) # compact show
689691
if length(c.opts) > 1
690-
join(io, c.opts, " => ")
692+
join(io, c.opts, " >> ")
691693
else
692694
show(io, MIME"text/plain"(), c)
693695
end

0 commit comments

Comments
 (0)