
Conversation

@AzamatB (Contributor) commented Jan 11, 2020

The current definition of `flip` uses `reverse`, which is implemented using mutation, so it is not compatible with Zygote. The proposed version uses `Zygote.Buffer` instead. As a consequence, it avoids unnecessary allocations, so it is also more efficient.
@DhairyaLGandhi (Member) commented Jan 11, 2020

This is really cool, thank you!
I wonder how that would interact with GPUs.

@MikeInnes (Member)

Yeah this seems like an issue for GPU. Why not just add an adjoint for reverse?

@AzamatB (Contributor, Author) commented Jan 14, 2020

@MikeInnes What would be the issue for GPU?
I mean it works with GPU:

using Flux
using CuArrays
import Zygote

function flip(f, xs)
   flipped_xs = Zygote.Buffer(xs)
   for t in Iterators.reverse(eachindex(xs))
      flipped_xs[t] = f(xs[t])
   end
   return copy(flipped_xs)
end

f = x -> 2x
xs = [rand(Float32, 2) for _ in 1:3]
xs = gpu(xs)
# this works
flip(f, xs)
# and so does this
xs = gpu(rand(Float32,5))
flip(f, xs)

@MikeInnes (Member)

Try with CuArrays.allowscalar(false). If you do scalar indexing on the GPU it's going to be really slow, even if it technically works.

@AzamatB (Contributor, Author) commented Jan 15, 2020

This is only a problem if xs is a GPUArray{<:Number}, which I have never seen being used with flip. The primary use case for flip is in bidirectional RNN models, where xs is a Vector{<:GPUArray}. In such cases, this is not a problem since scalar indexing is still happening on CPU.

@MikeInnes (Member)

Even so, it seems like it'd be preferable to just add an adjoint for reverse, rather than writing out everything that might use it by hand.

@AzamatB (Contributor, Author) commented Feb 3, 2020

I've refactored `flip` to use broadcasting. It is now GPU-friendly and mutation-free, while remaining just as efficient.

@MikeInnes, @dhairyagandhi96 how about this version?
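A broadcasting-based `flip` along these lines might look roughly as follows. This is a sketch, not necessarily the exact refactor from this PR; in particular, the `end:-1:1` indexing is an assumption standing in for `reverse`, which at the time had no Zygote adjoint:

```julia
import Zygote

# Sketch of a broadcasting-based flip (assumption: may differ from the PR's
# actual refactor). Indexing with end:-1:1 reverses the sequence without
# calling Base.reverse, and the broadcast f.(...) applies f without
# explicit mutation, so the whole thing stays Zygote-differentiable.
flip(f, xs) = f.(xs[end:-1:1])[end:-1:1]

flip(x -> 2x, [1, 2, 3])  # returns [2, 4, 6]

# Differentiable end to end:
Zygote.gradient(xs -> sum(flip(x -> 2x, xs)), [1.0, 2.0, 3.0])  # ([2.0, 2.0, 2.0],)
```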

@MikeInnes (Member)

That's so much more complex than the original code. Why not just make reverse work?
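For context, an adjoint for `reverse` in the spirit of this suggestion could be sketched as below. This is an illustration only, not the implementation that was eventually merged into Zygote; since reversal is linear, the pullback simply reverses the incoming gradient:

```julia
import Zygote
using Zygote: @adjoint

# Sketch of an adjoint for Base.reverse on vectors (illustrative; the merged
# Zygote definition may differ). Reversing is a linear operation, so the
# pullback just reverses the incoming cotangent.
@adjoint Base.reverse(x::AbstractVector) = reverse(x), ȳ -> (reverse(ȳ),)

# With this in place, the original reverse-based flip differentiates:
flip(f, xs) = reverse(f.(reverse(xs)))
Zygote.gradient(xs -> sum(flip(x -> 2x, xs)), [1.0, 2.0, 3.0])  # ([2.0, 2.0, 2.0],)
```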

bors bot added a commit to FluxML/Zygote.jl that referenced this pull request Feb 25, 2020
515: Fix Flux.flip by providing an adjoint for Base.reverse r=dhairyagandhi96 a=tanhevg

The main motivation behind this PR is to address various issues concerning `Flux.flip()` (used mainly for bRNNs), e.g.  FluxML/Flux.jl#962, FluxML/Flux.jl#990 and FluxML/model-zoo#179

Co-authored-by: Evgeny Tankhilevich <[email protected]>
@CarloLucibello (Member)

Now that FluxML/Zygote.jl#515 is merged, we should just add a test for `flip` in Flux to avoid regressions. @AzamatB would you like to turn this PR into just adding some tests?

@ToucheSir (Member)

@AzamatB do you still want to pursue adding a regression test for flip as Carlo mentioned? Otherwise it would be better to close this and start a new PR for that.
