diff --git a/src/Fluxperimental.jl b/src/Fluxperimental.jl
index 91438b0..04d3aad 100644
--- a/src/Fluxperimental.jl
+++ b/src/Fluxperimental.jl
@@ -11,6 +11,10 @@ export shinkansen!
 
 include("chain.jl")
 
+
+include("recur.jl")
+
+
 include("compact.jl")
 
 end # module Fluxperimental
diff --git a/src/chain.jl b/src/chain.jl
index 0095de3..2df573a 100644
--- a/src/chain.jl
+++ b/src/chain.jl
@@ -10,6 +10,11 @@ function apply(chain::Flux.Chain, x)
   Flux.Chain(layers), out
 end
 
+function apply(chain::Flux.Chain, x::Union{AbstractVector{<:AbstractArray}, Base.Generator})
+  layers, out = _apply(chain.layers, x)
+  Flux.Chain(layers), out
+end
+
 function _apply(layers::NamedTuple{NMS, TPS}, x) where {NMS, TPS}
   layers, out = _apply(Tuple(layers), x)
   NamedTuple{NMS}(layers), out
@@ -18,7 +23,7 @@ end
 function _scan(layers::AbstractVector, x)
   new_layers = typeof(layers)(undef, length(layers))
   for (idx, f) in enumerate(layers)
-    new_layers[idx], x = _apply(f, x)
+    new_layers[idx], x = apply(f, x)
   end
   new_layers, x
 end
@@ -27,7 +32,7 @@ end
 # example pulled from https://github.com/mcabbott/Flux.jl/blob/chain_rrule/src/cuda/cuda.jl
 function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig, ::typeof(_scan), layers, x)
   duo = accumulate(layers; init=((nothing, x), nothing)) do ((pl, input), _), cur_layer
-    out, back = ChainRulesCore.rrule_via_ad(cfg, _apply, cur_layer, input)
+    out, back = ChainRulesCore.rrule_via_ad(cfg, apply, cur_layer, input)
   end
   outs = map(first, duo)
   backs = map(last, duo)
@@ -52,11 +57,11 @@ end
 @generated function _apply(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
   x_symbols = vcat(:x, [gensym() for _ in 1:N])
   l_symbols = [gensym() for _ in 1:N]
-  calls = [:(($(l_symbols[i]), $(x_symbols[i+1])) = _apply(layers[$i], $(x_symbols[i]))) for i in 1:N]
+  calls = [:(($(l_symbols[i]), $(x_symbols[i+1])) = apply(layers[$i], $(x_symbols[i]))) for i in 1:N]
   push!(calls, :(return tuple($(l_symbols...)), $(x_symbols[end])))
   Expr(:block, calls...)
 end
 
-_apply(layer, x) = layer, layer(x)
+apply(layer, x) = layer, layer(x)
diff --git a/src/recur.jl b/src/recur.jl
new file mode 100644
index 0000000..6709a11
--- /dev/null
+++ b/src/recur.jl
@@ -0,0 +1,97 @@
+"""
+    NM_Recur
+Non-mutating Recur. An experimental recurrence interface for the new chain `apply` API.
+""" +struct NM_Recur{RET_SEQUENCE, T, S} + cell::T + state::S + function NM_Recur(cell, state; return_sequence::Bool=false) + new{return_sequence, typeof(cell), typeof(state)}(cell, state) + end + function NM_Recur{true}(cell, state) + new{true, typeof(cell), typeof(state)}(cell, state) + end + function NM_Recur{false}(cell, state) + new{false, typeof(cell), typeof(state)}(cell, state) + end +end + +function apply(m::NM_Recur, x) + state, y = m.cell(m.state, x) + return NM_Recur(m.cell, state), y +end + +# This is the same way we do 3-tensers from Flux.Recur +function apply(m::NM_Recur{true}, x::AbstractArray{T, 3}) where T + # h = [m(x_t) for x_t in eachlastdim(x)] + l, h = apply(m, Flux.eachlastdim(x)) + sze = size(h[1]) + l, reshape(reduce(hcat, h), sze[1], sze[2], length(h)) +end + +function apply(m::NM_Recur{false}, x::AbstractArray{T, 3}) where T + apply(m, Flux.eachlastdim(x)) +end + +function apply(l::NM_Recur{false}, xs::Union{AbstractVector{<:AbstractArray}, Base.Generator}) + rnn = l.cell + # carry = layer.stamte + x_init, x_rest = Iterators.peel(xs) + (carry, y) = rnn(l.state, x_init) + for x in x_rest + (carry, y) = rnn(carry, x) + end + NM_Recur{false}(rnn, carry), y +end + +# From Lux.jl: https://github.com/LuxDL/Lux.jl/pull/287/ +function apply(l::NM_Recur{true}, xs::Union{AbstractVector{<:AbstractArray}, Base.Generator}) + rnn = l.cell + _xs = if xs isa Base.Generator + collect(xs) # TODO: Fix. I can't figure out how to get around this for generators. + else + xs + end + x_init, _ = Iterators.peel(_xs) + + (carry, out_) = rnn(l.state, x_init) + + init = (typeof(out_)[out_], carry) + + function recurrence_op(input, (outputs, carry)) + carry, out = rnn(carry, input) + return vcat(outputs, typeof(out)[out]), carry + end + results = foldr(recurrence_op, _xs[(begin+1):end]; init) + return NM_Recur{true}(rnn, results[1][end]), first(results) +end + +Flux.@functor NM_Recur +Flux.trainable(a::NM_Recur) = (; cell = a.cell) + +Base.show(io::IO, m::NM_Recur) = print(io, "Recur(", m.cell, ")") + +NM_RNN(a...; return_sequence::Bool=false, ka...) = NM_Recur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence) +NM_Recur(m::Flux.RNNCell; return_sequence::Bool=false) = NM_Recur(m, m.state0; return_sequence=return_sequence) + +# Quick Reset functionality + +struct RecurWalk <: Flux.Functors.AbstractWalk end +(::RecurWalk)(recurse, x) = x isa Fluxperimental.NM_Recur ? reset(x) : Flux.Functors.DefaultWalk()(recurse, x) + +function reset(m::NM_Recur{SEQ}) where SEQ + NM_Recur{SEQ}(m.cell, m.cell.state0) +end +reset(m) = m +function reset(m::Flux.Chain) + ret = Flux.Functors.fmap((l)->l, m; walk=RecurWalk()) +end + + +## +# Fallback apply timeseries data to other layers. Likely needs to be thoought through a bit more. 
+##
+
+function apply(l, xs::Union{AbstractVector{<:AbstractArray}, Base.Generator})
+  l, [l(x) for x in xs]
+end
diff --git a/test/recur.jl b/test/recur.jl
new file mode 100644
index 0000000..3d1f47f
--- /dev/null
+++ b/test/recur.jl
@@ -0,0 +1,109 @@
+
+@testset "RNN gradients-implicit" begin
+  cell = Flux.RNNCell(1, 1, identity)
+  layer = Flux.Recur(cell)
+  layer.cell.Wi .= 5.0
+  layer.cell.Wh .= 4.0
+  layer.cell.b .= 0.0f0
+  layer.cell.state0 .= 7.0
+  x = [[2.0f0], [3.0f0]]
+
+  # theoretical primal value and gradients
+  primal =
+    layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
+    x[2] .* layer.cell.Wi
+  ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
+  ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
+  ∇b = layer.cell.Wh .+ 1
+  ∇state0 = layer.cell.Wh .^ 2
+
+  nm_layer = Fluxperimental.NM_Recur(cell; return_sequence = true)
+  ps = Flux.params(nm_layer)
+  e, g = Flux.withgradient(ps) do
+    l, out = Fluxperimental.apply(nm_layer, x)
+    sum(out[2])
+  end
+
+  @test primal[1] ≈ e
+  @test ∇Wi ≈ g[ps[1]]
+  @test ∇Wh ≈ g[ps[2]]
+  @test ∇b ≈ g[ps[3]]
+  @test ∇state0 ≈ g[ps[4]]
+end
+
+@testset "RNN gradients-implicit-partial sequence" begin
+  cell = Flux.RNNCell(1, 1, identity)
+  layer = Flux.Recur(cell)
+  layer.cell.Wi .= 5.0
+  layer.cell.Wh .= 4.0
+  layer.cell.b .= 0.0f0
+  layer.cell.state0 .= 7.0
+  x = [[2.0f0], [3.0f0]]
+
+  # theoretical primal value and gradients
+  primal =
+    layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
+    x[2] .* layer.cell.Wi
+  ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
+  ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
+  ∇b = layer.cell.Wh .+ 1
+  ∇state0 = layer.cell.Wh .^ 2
+
+  nm_layer = Fluxperimental.NM_Recur(cell; return_sequence = false)
+  ps = Flux.params(nm_layer)
+  e, g = Flux.withgradient(ps) do
+    l, out = Fluxperimental.apply(nm_layer, x)
+    sum(out)
+  end
+
+  @test primal[1] ≈ e
+  @test ∇Wi ≈ g[ps[1]]
+  @test ∇Wh ≈ g[ps[2]]
+  @test ∇b ≈ g[ps[3]]
+  @test ∇state0 ≈ g[ps[4]]
+end
+
+@testset "RNN gradients-explicit partial sequence" begin
+
+
+  cell = Flux.RNNCell(1, 1, identity)
+  layer = Flux.Recur(cell)
+  layer.cell.Wi .= 5.0
+  layer.cell.Wh .= 4.0
+  layer.cell.b .= 0.0f0
+  layer.cell.state0 .= 7.0
+  x = [[2.0f0], [3.0f0]]
+
+  # theoretical primal value and gradients
+  primal =
+    layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
+    x[2] .* layer.cell.Wi
+  ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
+  ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
+  ∇b = layer.cell.Wh .+ 1
+  ∇state0 = layer.cell.Wh .^ 2
+
+
+
+  nm_layer = Fluxperimental.NM_Recur(cell; return_sequence = false)
+  e, g = Flux.withgradient(nm_layer) do layer
+    r_l = Fluxperimental.reset(layer)
+    l, out = Fluxperimental.apply(r_l, x)
+    sum(out)
+  end
+  grads = g[1][:cell]
+
+  @test primal[1] ≈ e
+
+  if VERSION < v"1.7"
+    @test ∇Wi ≈ grads[:Wi]
+    @test ∇Wh ≈ grads[:Wh]
+    @test ∇b ≈ grads[:b]
+    @test ∇state0 ≈ grads[:state0]
+  else
+    @test ∇Wi ≈ grads[:Wi]
+    @test ∇Wh ≈ grads[:Wh]
+    @test ∇b ≈ grads[:b]
+    @test ∇state0 ≈ grads[:state0]
+  end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 55315cc..fe327ba 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -6,6 +6,8 @@ using Flux, Fluxperimental
 
   include("chain.jl")
 
+  include("recur.jl")
+
   include("compact.jl")
 
 end
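
A minimal usage sketch of the non-mutating recurrence introduced above. The layer sizes, batch size, and sequence length are illustrative only, and the unexported names (NM_RNN, apply) are accessed through the Fluxperimental module, as the tests do:

using Flux, Fluxperimental

# NM_RNN wraps a Flux.RNNCell; the hidden state is carried through `apply`
# instead of being mutated in place the way Flux.Recur does.
rnn = Fluxperimental.NM_RNN(2, 4; return_sequence = false)

# A sequence of 3 timesteps, each a 2×5 matrix (features × batch).
xs = [rand(Float32, 2, 5) for _ in 1:3]

# `apply` returns the layer carrying the new state plus the output;
# with return_sequence = false that is the output of the last timestep.
rnn_stepped, y_last = Fluxperimental.apply(rnn, xs)  # y_last is a 4×5 matrix

# With return_sequence = true every timestep's output is kept.
rnn_seq = Fluxperimental.NM_RNN(2, 4; return_sequence = true)
_, ys = Fluxperimental.apply(rnn_seq, xs)             # ys is a 3-element vector of 4×5 matrices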
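
The same layer composes with Flux.Chain through the new apply(::Chain, ::AbstractVector{<:AbstractArray}) method, with the fallback apply mapping non-recurrent layers over the timesteps. A sketch, again with made-up sizes, including a state reset via Fluxperimental.reset:

using Flux, Fluxperimental

model = Flux.Chain(
  Fluxperimental.NM_RNN(2, 4; return_sequence = true),
  Flux.Dense(4 => 1),
)

xs = [rand(Float32, 2, 5) for _ in 1:3]

# The sequence is threaded through every layer; the Dense layer is mapped
# over the per-timestep RNN outputs by the fallback apply method.
model_stepped, ys = Fluxperimental.apply(model, xs)  # ys is a vector of 1×5 matrices

# reset rebuilds the chain with every NM_Recur restored to its cell's state0,
# ready for a fresh sequence.
model_fresh = Fluxperimental.reset(model_stepped)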