NewRecur experimental interface #11
```diff
@@ -13,4 +13,6 @@ include("chain.jl")
 
 include("compact.jl")
 
+include("new_recur.jl")
+
 end # module Fluxperimental
```
`new_recur.jl` (new file, +105 lines):

```julia
##### Helper scan function, which can likely be put into NNlib. #####

"""
    scan

Recreates `jax.lax.scan` functionality in Julia.
"""
function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray})
  # Note: `Iterators.peel` is deliberate here. Replacing it with plain
  # indexing produced wrong gradients (see the review discussion below).
  x_init, x_rest = Iterators.peel(xs)

  (carry, out_) = func(init_carry, x_init)

  init = (typeof(out_)[out_], carry)

  function recurrence_op((outputs, carry), input)
    carry, out = func(carry, input)
    return vcat(outputs, typeof(out)[out]), carry
  end
  # Fold the remaining steps left-to-right, threading the carry through.
  results = foldl(recurrence_op, x_rest; init=init)
  results[2], results[1]  # (carry, outputs)
end

function scan_full(func, init_carry, x_block)
  xs_ = Flux.eachlastdim(x_block)
  xs = if xs_ isa Base.Generator
    collect(xs_)  # eachlastdim produces a generator in a non-gradient environment
  else
    xs_
  end
  scan_full(func, init_carry, xs)
end

function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray})
  x_init, x_rest = Iterators.peel(xs)
  (carry, y) = func(init_carry, x_init)
  for x in x_rest
    (carry, y) = func(carry, x)
  end
  carry, y
end

function scan_partial(func, init_carry, x_block)
  xs_ = Flux.eachlastdim(x_block)
  xs = if xs_ isa Base.Generator
    collect(xs_)  # eachlastdim produces a generator in a non-gradient environment
  else
    xs_
  end
  scan_partial(func, init_carry, xs)
end

"""
    NewRecur

New `Recur`. An experimental recur interface for removing statefulness in
recurrent architectures for Flux.
"""
struct NewRecur{RET_SEQUENCE, T}
  cell::T
  # state::S
  function NewRecur(cell; return_sequence::Bool=false)
    new{return_sequence, typeof(cell)}(cell)
  end
  function NewRecur{true}(cell)
    new{true, typeof(cell)}(cell)
  end
  function NewRecur{false}(cell)
    new{false, typeof(cell)}(cell)
  end
end

Flux.@functor NewRecur
Flux.trainable(a::NewRecur) = (; cell = a.cell)
Base.show(io::IO, m::NewRecur) = print(io, "NewRecur(", m.cell, ")")
NewRNN(a...; return_sequence::Bool=false, ka...) =
  NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence)

# A single step is ambiguous (features×batch vs features×time), so refuse
# matrices and vectors rather than guess the layout:
(l::NewRecur)(init_carry, x_mat::AbstractMatrix) =
  throw(MethodError(l, (init_carry, x_mat)))  # Matrix is ambiguous with NewRecur
(l::NewRecur)(init_carry, x_vec::AbstractVector{T}) where {T<:Number} =
  throw(MethodError(l, (init_carry, x_vec)))  # Vector is ambiguous with NewRecur

(l::NewRecur)(xs) = l(l.cell.state0, xs)
```
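A usage sketch of the interface above, assembled from the tests later in this PR rather than any documented API (the toy `accum` function is invented for illustration):

```julia
using Flux, Fluxperimental

# scan_full threads a carry through a sequence, like jax.lax.scan,
# returning (final_carry, per_step_outputs):
accum(carry, x) = (carry .+ x, carry .+ x)  # toy "cell": running sum
carry, outs = Fluxperimental.scan_full(accum, [0.0f0], [[1.0f0], [2.0f0], [3.0f0]])
carry  # [6.0f0]
outs   # [[1.0f0], [3.0f0], [6.0f0]]

# NewRecur wraps a Flux cell and consumes the whole sequence at once;
# time is the last dimension of a 3-d block:
x = reshape([2.0f0, 3.0f0], 1, 1, 2)  # (features, batch, time)

seq = Fluxperimental.NewRecur(Flux.RNNCell(1, 1, identity); return_sequence=true)
size(seq(x))  # (1, 1, 2): one output slice per time step

last_only = Fluxperimental.NewRecur(Flux.RNNCell(1, 1, identity))
size(last_only(x))  # (1, 1): output of the final step only

# Vectors and matrices are rejected as ambiguous:
# seq([2.0f0;; 3.0f0])  # throws MethodError
```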
Review (suggested change):

```diff
-(l::NewRecur)(xs) = l(l.cell.state0, xs)
+(l::NewRecur)(xs::AbstractArray) = l(l.cell.state0, xs)
```

For your consideration.

Reply: Likely this would be a good idea. But I didn't restrict this in this initial pass.
Review (suggested change):

```diff
-function (l::NewRecur{true})(init_carry,
-                             xs,)
-  results = scan_full(l.cell, init_carry, xs)
-  h = results[2]
-  sze = size(h[1])
-  reshape(reduce(hcat, h), sze[1], sze[2], length(h))
-end
+function (l::NewRecur{true})(init_carry, xs)
+  results = scan_full(l.cell, init_carry, xs)
+  return results[1], stack(results[2])
+end
```

Similar story here: `stack` vs `reduce(hcat, ...)` are more or less equally efficient, so feel free to use either.

Reply: Changed this to `stack` like you suggested.

Reply: I didn't realize `stack` is not available in 1.6. This function was added in 1.9. Are we targeting only 1.9?

Review: It's available via Compat.jl, which Flux already has as a transitive dependency, so there's zero additional import overhead.
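As context for the `stack` suggestion, a small sketch (not from the PR) of the equivalence being discussed: for a vector of equally sized arrays, stacking along a new trailing dimension gives the same result as the reshape-of-`reduce(hcat, ...)` pattern it replaces.

```julia
# `stack` is in Base from Julia 1.9; on older versions it comes from Compat.jl.
h = [rand(Float32, 2, 3) for _ in 1:4]  # hypothetical per-step outputs

a = stack(h)  # 2×3×4, with a[:, :, k] == h[k]
b = reshape(reduce(hcat, h), size(h[1])..., length(h))

@assert a == b  # identical layout, since reshape is column-major
```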
Test file `new_recur.jl` (new file, +182 lines):

```julia
@testset "NewRecur RNN" begin
  @testset "Forward Pass" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Fluxperimental.NewRecur(cell; return_sequence=true)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = reshape([2.0f0, 3.0f0], 1, 1, 2)

    @test eltype(layer(x)) <: Float32
    @test size(layer(x)) == (1, 1, 2)

    @test_throws MethodError layer([2.0f0])
    @test_throws MethodError layer([2.0f0;; 3.0f0])
  end

  @testset "gradients-implicit" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Flux.Recur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = [[2.0f0], [3.0f0]]

    # theoretical primal and gradients (see the derivation below)
    primal =
      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
      x[2] .* layer.cell.Wi
    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
    ∇b = layer.cell.Wh .+ 1
    ∇state0 = layer.cell.Wh .^ 2

    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
    ps = Flux.params(nm_layer)
    x_block = reshape(vcat(x...), 1, 1, length(x))
    e, g = Flux.withgradient(ps) do
      out = nm_layer(x_block)
      sum(out[1, 1, 2])
    end

    @test primal[1] ≈ e
    @test ∇Wi ≈ g[ps[1]]
    @test ∇Wh ≈ g[ps[2]]
    @test ∇b ≈ g[ps[3]]
    @test ∇state0 ≈ g[ps[4]]
  end

  @testset "gradients-explicit" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Flux.Recur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = [[2.0f0], [3.0f0]]

    # theoretical primal and gradients (see the derivation below)
    primal =
      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
      x[2] .* layer.cell.Wi
    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
    ∇b = layer.cell.Wh .+ 1
    ∇state0 = layer.cell.Wh .^ 2

    x_block = reshape(vcat(x...), 1, 1, length(x))
    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
    e, g = Flux.withgradient(nm_layer) do layer
      out = layer(x_block)
      sum(out[1, 1, 2])
    end
    grads = g[1][:cell]

    @test primal[1] ≈ e
    @test ∇Wi ≈ grads[:Wi]
    @test ∇Wh ≈ grads[:Wh]
    @test ∇b ≈ grads[:b]
    @test ∇state0 ≈ grads[:state0]
  end
end

@testset "New Recur RNN Partial Sequence" begin
  @testset "Forward Pass" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Fluxperimental.NewRecur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = reshape([2.0f0, 3.0f0], 1, 1, 2)

    @test eltype(layer(x)) <: Float32
    @test size(layer(x)) == (1, 1)

    @test_throws MethodError layer([2.0f0])
    @test_throws MethodError layer([2.0f0;; 3.0f0])
  end

  @testset "gradients-implicit" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Flux.Recur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = [[2.0f0], [3.0f0]]

    # theoretical primal and gradients (see the derivation below)
    primal =
      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
      x[2] .* layer.cell.Wi
    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
    ∇b = layer.cell.Wh .+ 1
    ∇state0 = layer.cell.Wh .^ 2

    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
    ps = Flux.params(nm_layer)
    x_block = reshape(vcat(x...), 1, 1, length(x))
    e, g = Flux.withgradient(ps) do
      out = nm_layer(x_block)
      sum(out)
    end

    @test primal[1] ≈ e
    @test ∇Wi ≈ g[ps[1]]
    @test ∇Wh ≈ g[ps[2]]
    @test ∇b ≈ g[ps[3]]
    @test ∇state0 ≈ g[ps[4]]
  end

  @testset "gradients-explicit" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Flux.Recur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = [[2.0f0], [3.0f0]]

    # theoretical primal and gradients (see the derivation below)
    primal =
      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
      x[2] .* layer.cell.Wi
    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
    ∇b = layer.cell.Wh .+ 1
    ∇state0 = layer.cell.Wh .^ 2

    x_block = reshape(vcat(x...), 1, 1, length(x))
    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
    e, g = Flux.withgradient(nm_layer) do layer
      out = layer(x_block)
      sum(out)
    end
    grads = g[1][:cell]

    @test primal[1] ≈ e
    @test ∇Wi ≈ grads[:Wi]
    @test ∇Wh ≈ grads[:Wh]
    @test ∇b ≈ grads[:b]
    @test ∇state0 ≈ grads[:state0]
  end
end
```
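The "theoretical primal gradients" repeated in these testsets come from unrolling the two-step identity-activation RNN by hand; as a sketch of that derivation (with `h0 = state0`, and `b = 0` in the primal as set above):

```latex
% h_t = W_i x_t + W_h h_{t-1} + b, and y is the output after step 2
\begin{aligned}
h_1 &= W_i x_1 + W_h h_0 + b \\
y = h_2 &= W_i x_2 + W_h h_1 + b
        = W_h (W_h h_0 + W_i x_1) + W_i x_2 \quad (\text{primal, with } b = 0) \\
\partial y / \partial W_i &= x_1 W_h + x_2 \\
\partial y / \partial W_h &= h_1 + W_h h_0 = 2 W_h h_0 + W_i x_1 \\
\partial y / \partial b   &= W_h + 1 \\
\partial y / \partial h_0 &= W_h^2
\end{aligned}
```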
```diff
@@ -8,4 +8,6 @@ using Flux, Fluxperimental
 
 include("compact.jl")
 
+include("new_recur.jl")
+
 end
```
Review: Since `x_rest` isn't used — or is `peel` more efficient in the contexts we care about?

Reply: We use `x_rest` in the `foldl` call to go through the rest of the sequence. The reason `peel` was chosen here was not efficiency but, oddly, gradients: when I instead wrote the version without `peel`, the resulting gradients were wrong. I've since documented this in the function.

Review: That seems very concerning. Do you have an MWE?

Reply: No, no MWE. Commenting out the `peel` and replacing it with the plain code block, I haven't been able to figure out where the gradients are going wrong here. I only know the tests fail.
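Since the thread above asks for an MWE, here is a hypothetical probe in that direction (not code from the PR; `scan_peel` and `scan_index` are invented names). Both variants compute the same primal, so any discrepancy would show up only in the gradients:

```julia
using Flux  # Flux re-exports Zygote's `gradient`

# Variant 1: split the sequence with Iterators.peel, as scan_partial does.
function scan_peel(func, init_carry, xs)
  x_init, x_rest = Iterators.peel(xs)
  carry, y = func(init_carry, x_init)
  for x in x_rest
    carry, y = func(carry, x)
  end
  carry, y
end

# Variant 2: the same loop written with plain indexing.
function scan_index(func, init_carry, xs)
  carry, y = func(init_carry, xs[begin])
  for x in xs[(begin+1):end]
    carry, y = func(carry, x)
  end
  carry, y
end

cell = Flux.RNNCell(1, 1, identity)
xs = [[2.0f0], [3.0f0]]

g_peel  = Flux.gradient(c -> sum(scan_peel(c, c.state0, xs)[2]), cell)
g_index = Flux.gradient(c -> sum(scan_index(c, c.state0, xs)[2]), cell)

# If the discrepancy is real, these will differ:
g_peel[1].Wi, g_index[1].Wi
```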