diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 499a21ab0f..03ae519a21 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -55,7 +55,10 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to reset!(m::Recur) = (m.state = m.init) reset!(m) = foreach(reset!, functor(m)[1]) -flip(f, xs) = reverse(f.(reverse(xs))) +function flip(f, xs) + rev_time = Iterators.reverse(eachindex(xs)) + return getindex.(Ref(f.(getindex.(Ref(xs), rev_time))), rev_time) +end # Vanilla RNN