From bff562fa479c81ae3e4543e7a2025630c51bf251 Mon Sep 17 00:00:00 2001 From: AzamatB Date: Sat, 11 Jan 2020 13:49:56 +0600 Subject: [PATCH 1/3] Make flip Zygote compatible The current definition of `flip` uses `reverse`, which implemented using mutation, so it is not compatible with Zygote. The proposed version instead uses `Zygote.Buffer` instead. As a consequence, the proposed version avoids unnecessary allocations so is more efficient --- src/layers/recurrent.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 499a21ab0f..cf20426ecf 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -55,7 +55,13 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to reset!(m::Recur) = (m.state = m.init) reset!(m) = foreach(reset!, functor(m)[1]) -flip(f, xs) = reverse(f.(reverse(xs))) +function flip(f, xs) + flipped_xs = Zygote.Buffer(xs) + for t ∈ Iterators.reverse(eachindex(xs)) + flipped_xs[t] = f(xs[t]) + end + return copy(flipped_xs) +end # Vanilla RNN From 7173070f6e9d7cc186ee3f15eac15d5f18a21d6b Mon Sep 17 00:00:00 2001 From: AzamatB Date: Mon, 3 Feb 2020 18:42:03 +0600 Subject: [PATCH 2/3] refactor flip --- src/layers/recurrent.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index cf20426ecf..afca60a55a 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -56,11 +56,8 @@ reset!(m::Recur) = (m.state = m.init) reset!(m) = foreach(reset!, functor(m)[1]) function flip(f, xs) - flipped_xs = Zygote.Buffer(xs) - for t ∈ Iterators.reverse(eachindex(xs)) - flipped_xs[t] = f(xs[t]) - end - return copy(flipped_xs) + rev_time = reverse(eachindex(xs)) + return getindex.(Ref(f.(getindex.(Ref(xs), rev_time))), rev_time) end # Vanilla RNN From 789d6dbf98bad770e129e238ab5bc74ebc4f1535 Mon Sep 17 00:00:00 2001 From: AzamatB Date: Mon, 3 Feb 2020 18:49:39 +0600 Subject: [PATCH 3/3] minor fix --- src/layers/recurrent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index afca60a55a..03ae519a21 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -56,7 +56,7 @@ reset!(m::Recur) = (m.state = m.init) reset!(m) = foreach(reset!, functor(m)[1]) function flip(f, xs) - rev_time = reverse(eachindex(xs)) + rev_time = Iterators.reverse(eachindex(xs)) return getindex.(Ref(f.(getindex.(Ref(xs), rev_time))), rev_time) end