Commit 6041cf5

Recurrence layer (#2549)
1 parent 009d35b commit 6041cf5

8 files changed: +146 −13 lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -75,4 +75,4 @@ jobs:
       env:
         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
         DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
-        DATADEPS_ALWAYS_ACCEPT: true
+

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
     OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
     DataFrames, JLD2, MLDataDevices
 
+ENV["DATADEPS_ALWAYS_ACCEPT"] = true
 
 DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)

docs/src/guide/models/basics.md

Lines changed: 2 additions & 2 deletions
@@ -185,7 +185,7 @@ These matching nested structures are at the core of how Flux works.
 <h3><img src="../../../assets/zygote-crop.png" width="40px"/>&nbsp;<a href="https://github.com/FluxML/Zygote.jl">Zygote.jl</a></h3>
 ```
 
-Flux's [`gradient`](@ref) function by default calls a companion packages called [Zygote](https://github.com/FluxML/Zygote.jl).
+Flux's [`gradient`](@ref Flux.gradient) function by default calls a companion packages called [Zygote](https://github.com/FluxML/Zygote.jl).
 Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
 hooks into Julia's compiler to find out what operations `f` contains, and transforms this
 to produce code for computing `∂f/∂x`.
@@ -372,7 +372,7 @@ How does this `model3` differ from the `model1` we had before?
   Its contents is stored in a tuple, thus `model3.layers[1].weight` is an array.
 * Flux's layer [`Dense`](@ref Flux.Dense) has only minor differences from our `struct Layer`:
   - Like `struct Poly3{T}` above, it has type parameters for its fields -- the compiler does not know exactly what type `layer3s.W` will be, which costs speed.
-  - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref) by default.
+  - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref Flux.glorot_uniform) by default.
   - It reshapes some inputs (to allow several batch dimensions), and produces more friendly errors on wrong-size input.
   - And it has some performance tricks: making sure element types match, and re-using some memory.
 * The function [`σ`](@ref NNlib.sigmoid) is calculated in a slightly better way,

docs/src/guide/models/recurrence.md

Lines changed: 47 additions & 0 deletions
@@ -166,3 +166,50 @@ opt_state = Flux.setup(AdamW(1e-3), model)
 g = gradient(m -> Flux.mse(m(x), y), model)[1]
 Flux.update!(opt_state, model, g)
 ```
+
+Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell to process the entire sequence at once. For instance, a type behaving the same as the `LSTM` layer can be defined as follows:
+
+```julia
+rnn = Recurrence(LSTMCell(2 => 3))  # similar to LSTM(2 => 3)
+x = rand(Float32, 2, 4, 3)
+y = rnn(x)
+```
+
+## Stacking recurrent layers
+
+Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain.
+For instance, a model with two LSTM layers can be defined as follows:
+
+```julia
+stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
+x = rand(Float32, 3, 4)
+y = stacked_rnn(x)
+```
+
+If more fine-grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows:
+
+```julia
+struct StackedRNN{L,S}
+    layers::L
+    states0::S
+end
+
+Flux.@layer StackedRNN
+
+function StackedRNN(d::Int; num_layers::Int)
+    layers = [LSTM(d => d) for _ in 1:num_layers]
+    states0 = [Flux.initialstates(l) for l in layers]
+    return StackedRNN(layers, states0)
+end
+
+function (m::StackedRNN)(x)
+    for (layer, state0) in zip(m.layers, m.states0)
+        x = layer(x, state0)
+    end
+    return x
+end
+
+rnn = StackedRNN(3; num_layers=2)
+x = rand(Float32, 3, 10)
+y = rnn(x)
+```
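
The guide's claim that `Recurrence(LSTMCell(...))` behaves like `LSTM(...)` can be checked by sharing a single cell between both wrappers, much as the new test at the bottom of this commit does. A minimal sketch, not part of the commit itself (variable names here are illustrative):

```julia
using Flux

lstm = LSTM(2 => 3)            # ordinary LSTM layer
rec  = Recurrence(lstm.cell)   # wrap the same cell, so the weights are shared

x = rand(Float32, 2, 4, 3)     # in × len × batch
rec(x) ≈ lstm(x)               # true: both scan the full sequence the same way
```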

docs/src/reference/models/layers.md

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ PairwiseFusion
 Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
 
 ```@docs
+Recurrence
 RNNCell
 RNN
 LSTMCell

src/Flux.jl

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ using EnzymeCore: EnzymeCore
 export Chain, Dense, Embedding, EmbeddingBag,
        Maxout, SkipConnection, Parallel, PairwiseFusion,
        RNNCell, LSTMCell, GRUCell, GRUv3Cell,
-       RNN, LSTM, GRU, GRUv3,
+       RNN, LSTM, GRU, GRUv3, Recurrence,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
        Dropout, AlphaDropout,

src/layers/recurrent.jl

Lines changed: 84 additions & 9 deletions
@@ -1,8 +1,7 @@
 out_from_state(state) = state
 out_from_state(state::Tuple) = state[1]
 
-function scan(cell, x, state0)
-    state = state0
+function scan(cell, x, state)
     y = []
     for x_t in eachslice(x, dims = 2)
         state = cell(x_t, state)
@@ -12,7 +11,67 @@ function scan(cell, x, state0)
     return stack(y, dims = 2)
 end
 
+"""
+    Recurrence(cell)
+
+Create a recurrent layer that processes entire sequences out
+of a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
+similarly to how [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.
+
+The `cell` should be a callable object that takes an input `x` and a hidden state `state` and returns
+a new hidden state `state'`. The `cell` should also implement the `initialstates` method that returns
+the initial hidden state. The output of the `cell` is considered to be:
+1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
+2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.
+
+# Forward
+
+    rnn(x, [state])
+
+The input `x` should be an array of size `in x len` or `in x len x batch_size`,
+where `in` is the input dimension of the cell, `len` is the sequence length, and `batch_size` is the batch size.
+The `state` should be a valid state for the recurrent cell. If not provided, it is obtained by calling
+`Flux.initialstates(cell)`.
+
+The output is an array of size `out x len x batch_size`, where `out` is the output dimension of the cell.
+
+The operation performed is semantically equivalent to the following code:
+```julia
+out_from_state(state) = state
+out_from_state(state::Tuple) = state[1]
 
+state = Flux.initialstates(cell)
+out = []
+for x_t in eachslice(x, dims = 2)
+    state = cell(x_t, state)
+    out = [out..., out_from_state(state)]
+end
+stack(out, dims = 2)
+```
+
+# Examples
+
+```jldoctest
+julia> rnn = Recurrence(RNNCell(2 => 3))
+Recurrence(
+  RNNCell(2 => 3, tanh),                # 18 parameters
+)                   # Total: 3 arrays, 18 parameters, 232 bytes.
+
+julia> x = rand(Float32, 2, 3, 4);  # in x len x batch_size
+
+julia> y = rnn(x);  # out x len x batch_size
+```
+"""
+struct Recurrence{M}
+    cell::M
+end
+
+@layer Recurrence
+
+initialstates(rnn::Recurrence) = initialstates(rnn.cell)
+
+(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
+(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)
 
 # Vanilla RNN
 @doc raw"""
@@ -185,9 +244,7 @@ julia> x = rand(Float32, (d_in, len, batch_size));
 julia> h = zeros(Float32, (d_out, batch_size));
 
 julia> rnn = RNN(d_in => d_out)
-RNN(
-  RNNCell(4 => 6, tanh),                # 66 parameters
-)                   # Total: 3 arrays, 66 parameters, 424 bytes.
+RNN(4 => 6, tanh)   # 66 parameters
 
 julia> y = rnn(x, h); # [y] = [d_out, len, batch_size]
 ```
@@ -212,7 +269,7 @@ struct RNN{M}
     cell::M
 end
 
-@layer RNN
+@layer :noexpand RNN
 
 initialstates(rnn::RNN) = initialstates(rnn.cell)
 
@@ -233,6 +290,12 @@ function (m::RNN)(x::AbstractArray, h)
     return scan(m.cell, x, h)
 end
 
+function Base.show(io::IO, m::RNN)
+    print(io, "RNN(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1))
+    print(io, ", ", m.cell.σ)
+    print(io, ")")
+end
+
 
 # LSTM
 @doc raw"""
@@ -401,7 +464,7 @@ struct LSTM{M}
     cell::M
 end
 
-@layer LSTM
+@layer :noexpand LSTM
 
 initialstates(lstm::LSTM) = initialstates(lstm.cell)
 
@@ -417,6 +480,10 @@ function (m::LSTM)(x::AbstractArray, state0)
     return scan(m.cell, x, state0)
 end
 
+function Base.show(io::IO, m::LSTM)
+    print(io, "LSTM(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 4, ")")
+end
+
 # GRU
 
 @doc raw"""
@@ -569,7 +636,7 @@ struct GRU{M}
     cell::M
 end
 
-@layer GRU
+@layer :noexpand GRU
 
 initialstates(gru::GRU) = initialstates(gru.cell)
 
@@ -585,6 +652,10 @@ function (m::GRU)(x::AbstractArray, h)
     return scan(m.cell, x, h)
 end
 
+function Base.show(io::IO, m::GRU)
+    print(io, "GRU(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
+end
+
 # GRU v3
 @doc raw"""
     GRUv3Cell(in => out; init_kernel = glorot_uniform,
@@ -729,7 +800,7 @@ struct GRUv3{M}
     cell::M
 end
 
-@layer GRUv3
+@layer :noexpand GRUv3
 
 initialstates(gru::GRUv3) = initialstates(gru.cell)
 
@@ -744,3 +815,7 @@ function (m::GRUv3)(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
     return scan(m.cell, x, h)
 end
+
+function Base.show(io::IO, m::GRUv3)
+    print(io, "GRUv3(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
+end
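
The new docstring pins down the interface a cell needs in order to be wrapped by `Recurrence`: it must be callable as `cell(x_t, state)` returning the new state, and provide `initialstates`. A minimal sketch of a custom cell meeting that contract, not part of this commit (the `MinimalCell` type and its field names are illustrative, not Flux API):

```julia
using Flux

# Hypothetical cell satisfying the Recurrence contract: callable as
# cell(x_t, state) returning the new state, plus a Flux.initialstates method.
struct MinimalCell{A,V}
    Wx::A   # input-to-hidden weights
    Wh::A   # hidden-to-hidden weights
    b::V    # bias
end

Flux.@layer MinimalCell

MinimalCell((din, dout)::Pair{Int,Int}) =
    MinimalCell(Flux.glorot_uniform(dout, din),
                Flux.glorot_uniform(dout, dout),
                zeros(Float32, dout))

# The returned state is a plain array, so Recurrence also uses it as the output.
(c::MinimalCell)(x_t, h) = tanh.(c.Wx * x_t .+ c.Wh * h .+ c.b)

Flux.initialstates(c::MinimalCell) = zeros(Float32, size(c.Wh, 1))

rnn = Recurrence(MinimalCell(2 => 3))
x = rand(Float32, 2, 5, 8)   # in × len × batch
y = rnn(x)                   # 3 × 5 × 8
```

Because the state here is an array rather than a tuple, this exercises case 2 of the docstring: the state itself is taken as the per-step output.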

test/layers/recurrent.jl

Lines changed: 9 additions & 0 deletions
@@ -305,3 +305,12 @@ end
     # no initial state same as zero initial state
     @test gru(x) ≈ gru(x, zeros(Float32, 4))
 end
+
+@testset "Recurrence" begin
+    x = rand(Float32, 2, 3, 4)
+    for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
+        cell = rnn.cell
+        rec = Recurrence(cell)
+        @test rec(x) ≈ rnn(x)
+    end
+end
