-
-
Couldn't load subscription status.
- Fork 615
Make flip Zygote compatible #990
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The current definition of `flip` uses `reverse`, which implemented using mutation, so it is not compatible with Zygote. The proposed version instead uses `Zygote.Buffer` instead. As a consequence, the proposed version avoids unnecessary allocations so is more efficient
|
This is really cool, thank you! |
|
Yeah this seems like an issue for GPU. Why not just add an adjoint for |
|
@MikeInnes What would be the issue for GPU? using Flux
using CuArrays
import Zygote
function flip(f, xs)
flipped_xs = Zygote.Buffer(xs)
for t ∈ Iterators.reverse(eachindex(xs))
flipped_xs[t] = f(xs[t])
end
return copy(flipped_xs)
end
f = x -> 2x
xs = [rand(Float32, 2) for _ ∈ 1:3]
xs = gpu(xs)
# this works
flip(f, xs)
# and so does this
xs = gpu(rand(Float32,5))
flip(f, xs) |
|
Try with |
|
This is only a problem if |
|
Even so, it seems like it'd be preferable to just add an adjoint for |
|
I've refactored flip to use broadcasting. Now it is GPU friendly, while also being as efficient and does not use mutation. @MikeInnes, @dhairyagandhi96 how about this version? |
|
That's so much more complex than the original code. Why not just make |
515: Fix Flux.flip by providing an adjoint for Base.reverse r=dhairyagandhi96 a=tanhevg The main motivation behind this PR is to address various issues concerning `Flux.flip()` (used mainly for bRNNs), e.g. FluxML/Flux.jl#962, FluxML/Flux.jl#990 and FluxML/model-zoo#179 Co-authored-by: Evgeny Tankhilevich <[email protected]>
|
Now that FluxML/Zygote.jl#515 is merged we should just a test for |
|
@AzamatB do you still want to pursue adding a regression test for |
The current definition of
flipusesreverse, which implemented using mutation, so it is not compatible with Zygote.The proposed version uses
Zygote.Bufferinstead. As a consequence, the proposed version avoids unnecessary allocations so is also more efficient