Merged
2 changes: 2 additions & 0 deletions src/Fluxperimental.jl
@@ -13,4 +13,6 @@ include("chain.jl")

include("compact.jl")

include("new_recur.jl")

end # module Fluxperimental
83 changes: 83 additions & 0 deletions src/new_recur.jl
@@ -0,0 +1,83 @@


"""
    NewRecur

An experimental `Recur` interface that removes statefulness from recurrent architectures in Flux.
"""
struct NewRecur{RET_SEQUENCE, T}
cell::T
# state::S
function NewRecur(cell; return_sequence::Bool=false)
new{return_sequence, typeof(cell)}(cell)
end
function NewRecur{true}(cell)
new{true, typeof(cell)}(cell)
end
function NewRecur{false}(cell)
new{false, typeof(cell)}(cell)
end
end
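
# Unlike Flux.Recur, NewRecur keeps no hidden state between calls: the carry is
# read from `cell.state0` (or passed explicitly) and threaded through the whole
# sequence in a single call. A minimal sketch, with hypothetical sizes:
#
#   cell  = Flux.RNNCell(2, 3)
#   layer = NewRecur(cell)                        # return_sequence=false
#   xs    = [rand(Float32, 2, 7) for _ in 1:5]    # 5 timesteps, batch of 7
#   y     = layer(xs)                             # output of the final timestep only
#   y     = layer(xs, cell.state0)                # carry can also be passed explicitly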

# This follows the same approach as Flux.Recur for 3-tensors.
function (m::NewRecur{false})(x::AbstractArray{T, N}, carry) where {T, N}
  @assert N >= 3
  cell = m.cell
  x_init, x_rest = Iterators.peel(Flux.eachlastdim(x))
  (carry, y) = cell(carry, x_init)
  for x_t in x_rest
    (carry, y) = cell(carry, x_t)
  end
  y
end

function (l::NewRecur{false})(x::AbstractArray{T, 3}, carry=l.cell.state0) where T
  l(Flux.eachlastdim(x), carry)
end
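
# Note: a 3-tensor input is assumed to be laid out with time as the last
# dimension (features × batch × time); `Flux.eachlastdim` slices it into one
# array per timestep before handing off to the sequence method below.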

function (l::NewRecur{false})(xs::Union{AbstractVector{<:AbstractArray}, Base.Generator},
carry=l.cell.state0)
rnn = l.cell
x_init, x_rest = Iterators.peel(xs)
(carry, y) = rnn(carry, x_init)
for x in x_rest
(carry, y) = rnn(carry, x)
end
y
end

# From Lux.jl: https://github.com/LuxDL/Lux.jl/pull/287/
function (l::NewRecur{true})(xs::Union{AbstractVector{<:AbstractArray}, Base.Generator},
carry=l.cell.state0)
rnn = l.cell
_xs = if xs isa Base.Generator
collect(xs) # TODO: Fix. I can't figure out how to get around this for generators.
else
xs
end
x_init, _ = Iterators.peel(_xs)

(carry, out_) = rnn(carry, x_init)

init = (typeof(out_)[out_], carry)

  function recurrence_op((outputs, carry), input)
    carry, out = rnn(carry, input)
    return vcat(outputs, typeof(out)[out]), carry
  end
  # foldl threads the carry through the remaining timesteps in sequence
  # (left-to-right) order.
  results = foldl(recurrence_op, _xs[(begin+1):end]; init)
# return NewRecur{true}(rnn, results[1][end]), first(results)
first(results)
end
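
# Illustrative trace for a hypothetical three-step sequence [x1, x2, x3]:
#   (carry, y1) = rnn(l.cell.state0, x1)         # peeled first step
#   init        = ([y1], carry)
#   foldl:  (carry, y2) = rnn(carry, x2)   ->  ([y1, y2], carry)
#           (carry, y3) = rnn(carry, x3)   ->  ([y1, y2, y3], carry)
# `first(results)` then returns the per-timestep outputs [y1, y2, y3].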

Flux.@functor NewRecur
Flux.trainable(a::NewRecur) = (; cell = a.cell)

Base.show(io::IO, m::NewRecur) = print(io, "NewRecur(", m.cell, ")")

NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence)
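
# A rough usage sketch of the convenience constructor (sizes are illustrative):
#
#   rnn = NewRNN(2, 3; return_sequence=true)     # wraps Flux.RNNCell(2, 3)
#   xs  = [rand(Float32, 2, 7) for _ in 1:5]     # 5 timesteps, batch of 7
#   ys  = rnn(xs)                                # 5 outputs, each 3×7
#
#   last_only = NewRNN(2, 3)                     # return_sequence=false (default)
#   y = last_only(xs)                            # only the final 3×7 output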

111 changes: 111 additions & 0 deletions test/new_recur.jl
@@ -0,0 +1,111 @@


@testset "RNN gradients-implicit" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2
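
  # Derivation sketch for the hand-computed values above (identity activation,
  # scalar weights, h0 = state0, b = 0):
  #   h1 = Wi*x[1] + Wh*h0 + b
  #   h2 = Wi*x[2] + Wh*h1 + b = Wi*x[2] + Wh*(Wi*x[1] + Wh*h0 + b) + b
  # Differentiating h2 (the quantity the loss below reduces to):
  #   ∂h2/∂Wi     = x[2] + Wh*x[1]
  #   ∂h2/∂Wh     = h1 + Wh*h0 = Wi*x[1] + 2*Wh*h0     (with b = 0)
  #   ∂h2/∂b      = 1 + Wh
  #   ∂h2/∂state0 = Wh^2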

nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
ps = Flux.params(nm_layer)
e, g = Flux.withgradient(ps) do
out = nm_layer(x)
sum(out[2])
end

@test primal[1] ≈ e
@test ∇Wi ≈ g[ps[1]]
@test ∇Wh ≈ g[ps[2]]
@test ∇b ≈ g[ps[3]]
@test ∇state0 ≈ g[ps[4]]
end

@testset "RNN gradients-implicit-partial sequence" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2

nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
ps = Flux.params(nm_layer)
e, g = Flux.withgradient(ps) do
out = nm_layer(x)
sum(out)
end

@test primal[1] ≈ e
@test ∇Wi ≈ g[ps[1]]
@test ∇Wh ≈ g[ps[2]]
@test ∇b ≈ g[ps[3]]
@test ∇state0 ≈ g[ps[4]]
end

@testset "RNN gradients-explicit partial sequence" begin


cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2



nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
e, g = Flux.withgradient(nm_layer) do layer
# r_l = Fluxperimental.reset(layer)
out = layer(x)
sum(out)
end
grads = g[1][:cell]

@test primal[1] ≈ e

@test ∇Wi ≈ grads[:Wi]
@test ∇Wh ≈ grads[:Wh]
@test ∇b ≈ grads[:b]
@test ∇state0 ≈ grads[:state0]
end

2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -8,4 +8,6 @@ using Flux, Fluxperimental

include("compact.jl")

include("new_recur.jl")

end