
Commit fc1c6b9

bors[bot] and jeremiedb authored
Merge #1521
1521: Fixes to Recurrent models for informative type mismatch error & output Vector for Vector input r=CarloLucibello a=jeremiedb

Minor fix to the recurrent layers so that a `Vector` input returns a `Vector` output, and so that an indicative error is raised when the eltype of the input doesn't match the eltype of the state. Also fixes some typos in the associated docs. As discussed in #1483.

Co-authored-by: jeremie.db <[email protected]>
Co-authored-by: jeremiedb <[email protected]>
2 parents 49c226d + 6aed1cf commit fc1c6b9
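In practice, the change means roughly the following (a minimal sketch, not part of the diff; assumes a Flux build containing this commit):

```julia
using Flux

m = RNN(2, 5)                 # Float32 parameters and hidden state by default

y = m(rand(Float32, 2))       # Vector input...
size(y)                       # (5,) -- ...now gives a Vector output rather than a 5×1 Matrix

# An input eltype that doesn't match the state's eltype now fails early:
# m(rand(Float64, 2))         # throws a MethodError flagging the type mismatch
```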

File tree: 3 files changed (+51 −26 lines)

docs/src/models/recurrence.md

Lines changed: 6 additions & 6 deletions
@@ -8,7 +8,7 @@ To introduce Flux's recurrence functionalities, we will consider the following v

 In the above, we have a sequence of length 3, where `x1` to `x3` represent the input at each step (could be a timestamp or a word in a sentence), and `y1` to `y3` are their respective outputs.

-An aspect to recognize is that in such model, the recurrent cells `A` all refer to the same structure. What distinguishes it from a dense layer for example is that the cell A is fed, in addition to an input `x`, with information from the previous state of the model (hidden state denoted as `h1` & `h2` in the diagram).
+An aspect to recognize is that in such model, the recurrent cells `A` all refer to the same structure. What distinguishes it from a simple dense layer is that the cell `A` is fed, in addition to an input `x`, with information from the previous state of the model (hidden state denoted as `h1` & `h2` in the diagram).

 In the most basic RNN case, cell A could be defined by the following:

@@ -69,15 +69,15 @@ Recur(RNNCell(2, 5, tanh))

 Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.

-Using these tools, we can now build the model is the above diagram with:
+Using these tools, we can now build the model shown in the above diagram with:

 ```julia
 m = Chain(RNN(2, 5), Dense(5, 1), x -> reshape(x, :))
 ```

 ## Working with sequences

-Using the previously defined `m` recurrent model, we can the apply it to a single step from our sequence:
+Using the previously defined `m` recurrent model, we can now apply it to a single step from our sequence:

 ```julia
 x = rand(Float32, 2)
@@ -86,7 +86,7 @@ julia> m(x)
 0.028398542
 ```

-The m(x) operation would be represented by `x1 -> A -> y1` in our diagram.
+The `m(x)` operation would be represented by `x1 -> A -> y1` in our diagram.
 If we perform this operation a second time, it will be equivalent to `x2 -> A -> y2` since the model `m` has stored the state resulting from the `x1` step:

 ```julia
@@ -98,7 +98,7 @@ julia> m(x)

 Now, instead of computing a single step at a time, we can get the full `y1` to `y3` sequence in a single pass by broadcasting the model on a sequence of data.

-To do so, we'll need to structure the input data as a `Vector` of observations at each time step. This `Vector` will therefore be of length = `seq_length` and each of its elements will represent the input features for a given step. In our example, this translates into a `Vector` of length 3, where each element is a `Matrix` of size `(features, batch_size)`, or just a `Vector` of length `features` if dealing with a single observation.
+To do so, we'll need to structure the input data as a `Vector` of observations at each time step. This `Vector` will therefore be of `length = seq_length` and each of its elements will represent the input features for a given step. In our example, this translates into a `Vector` of length 3, where each element is a `Matrix` of size `(features, batch_size)`, or just a `Vector` of length `features` if dealing with a single observation.

 ```julia
 x = [rand(Float32, 2) for i = 1:3]
@@ -170,4 +170,4 @@ function loss(x, y)
 end
 ```

-A potential source of ambiguity of RNN in Flux can come from the different data layout compared to some common frameworks where data is typically a 3 dimensional array: `(features, seq length, samples)`. In Flux, those 3 dimensions are provided through a vector of seq length containing a matrix `(features, samples)`.
+A potential source of ambiguity with RNN in Flux can come from the different data layout compared to some common frameworks where data is typically a 3 dimensional array: `(features, seq length, samples)`. In Flux, those 3 dimensions are provided through a vector of seq length containing a matrix `(features, samples)`.
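As an aside on that last paragraph, a short sketch of going from the common 3-dimensional layout to the layout Flux expects (illustrative only; the variable names are made up, not taken from the docs):

```julia
using Flux

features, seq_length, samples = 2, 3, 4
X = rand(Float32, features, seq_length, samples)   # common (features, seq length, samples) layout

# Flux's layout: a Vector of length seq_length whose elements are (features, samples) matrices
xs = [X[:, t, :] for t in 1:seq_length]

m = Chain(RNN(features, 5), Dense(5, 1), x -> reshape(x, :))
Flux.reset!(m)
ys = m.(xs)    # broadcasting over the sequence gives a Vector of 3 outputs
```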

src/layers/recurrent.jl

Lines changed: 11 additions & 10 deletions
@@ -30,8 +30,8 @@ mutable struct Recur{T,S}
   state::S
 end

-function (m::Recur)(xs...)
-  m.state, y = m.cell(m.state, xs...)
+function (m::Recur)(x)
+  m.state, y = m.cell(m.state, x)
   return y
 end

@@ -80,10 +80,11 @@ end
 RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
   RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

-function (m::RNNCell)(h, x)
+function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
   σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
   h = σ.(Wi*x .+ Wh*h .+ b)
-  return h, h
+  sz = size(x)
+  return h, reshape(h, :, sz[2:end]...)
 end

 @functor RNNCell
@@ -133,7 +134,7 @@ function LSTMCell(in::Integer, out::Integer;
   return cell
 end

-function (m::LSTMCell)((h, c), x)
+function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
   b, o = m.b, size(h, 1)
   g = m.Wi*x .+ m.Wh*h .+ b
   input = σ.(gate(g, o, 1))
@@ -142,7 +143,8 @@ function (m::LSTMCell)((h, c), x)
   output = σ.(gate(g, o, 4))
   c = forget .* c .+ input .* cell
   h′ = output .* tanh.(c)
-  return (h′, c), h′
+  sz = size(x)
+  return (h′, c), reshape(h′, :, sz[2:end]...)
 end

 @functor LSTMCell
@@ -160,8 +162,6 @@ See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 """
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
-# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
-#                               (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
 Recur(m::LSTMCell) = Recur(m, m.state0)

 # TODO remove in v0.13
@@ -193,14 +193,15 @@ end
 GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
   GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))

-function (m::GRUCell)(h, x)
+function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
   b, o = m.b, size(h, 1)
   gx, gh = m.Wi*x, m.Wh*h
   r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
   z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
   h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
   h′ = (1 .- z) .* h̃ .+ z .* h
-  return h′, h′
+  sz = size(x)
+  return h′, reshape(h′, :, sz[2:end]...)
 end

 @functor GRUCell
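A note on the recurring `reshape(h, :, sz[2:end]...)` pattern added above: the hidden state is stored with an explicit trailing batch dimension (`init_state(out, 1)`), so the broadcasted result always carries that dimension; reshaping against the input's trailing dimensions collapses it for `Vector` inputs and leaves `Matrix` inputs unchanged. A small sketch of just that reshape, with illustrative values (plain Base Julia, not part of the diff):

```julia
out = 5

x_vec = rand(Float32, 2)        # single observation, no batch dimension
h_vec = rand(Float32, out, 1)   # what the broadcast produces internally
size(reshape(h_vec, :, size(x_vec)[2:end]...))   # (5,)   -> Vector out for Vector in

x_mat = rand(Float32, 2, 7)     # batch of 7 observations
h_mat = rand(Float32, out, 7)
size(reshape(h_mat, :, size(x_mat)[2:end]...))   # (5, 7) -> unchanged for Matrix in
```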

test/layers/recurrent.jl

Lines changed: 34 additions & 10 deletions
@@ -1,14 +1,14 @@
 # Ref FluxML/Flux.jl#1209 1D input
-@testset "BPTT" begin
+@testset "BPTT-1D" begin
   seq = [rand(Float32, 2) for i = 1:3]
   for r ∈ [RNN,]
-    rnn = r(2,3)
+    rnn = r(2, 3)
     Flux.reset!(rnn)
     grads_seq = gradient(Flux.params(rnn)) do
-        sum(rnn.(seq)[3])
+      sum(rnn.(seq)[3])
     end
     Flux.reset!(rnn);
-    bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
+    bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                             tanh.(rnn.cell.Wi * seq[2] + Wh *
                                   tanh.(rnn.cell.Wi * seq[1] +
                                         Wh * rnn.cell.state0
@@ -17,20 +17,20 @@
                                         + rnn.cell.b)),
                     rnn.cell.Wh)
     @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
-    end
+  end
 end

 # Ref FluxML/Flux.jl#1209 2D input
-@testset "BPTT" begin
-  seq = [rand(Float32, (2,1)) for i = 1:3]
+@testset "BPTT-2D" begin
+  seq = [rand(Float32, (2, 1)) for i = 1:3]
   for r ∈ [RNN,]
-    rnn = r(2,3)
+    rnn = r(2, 3)
     Flux.reset!(rnn)
     grads_seq = gradient(Flux.params(rnn)) do
-        sum(rnn.(seq)[3])
+      sum(rnn.(seq)[3])
     end
     Flux.reset!(rnn);
-    bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
+    bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                             tanh.(rnn.cell.Wi * seq[2] + Wh *
                                   tanh.(rnn.cell.Wi * seq[1] +
                                         Wh * rnn.cell.state0
@@ -39,5 +39,29 @@ end
                                         + rnn.cell.b)),
                     rnn.cell.Wh)
     @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
+  end
+end
+
+@testset "RNN-shapes" begin
+  @testset for R in [RNN, GRU, LSTM]
+    m1 = R(3, 5)
+    m2 = R(3, 5)
+    x1 = rand(Float32, 3)
+    x2 = rand(Float32,3,1)
+    Flux.reset!(m1)
+    Flux.reset!(m2)
+    @test size(m1(x1)) == (5,)
+    @test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape
+    @test size(m2(x2)) == (5,1)
+    @test size(m2(x2)) == (5,1)
+  end
+end
+
+@testset "RNN-input-state-eltypes" begin
+  @testset for R in [RNN, GRU, LSTM]
+    m = R(3, 5)
+    x = rand(Float64, 3, 1)
+    Flux.reset!(m)
+    @test_throws MethodError m(x)
   end
 end
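The new `RNN-input-state-eltypes` testset pins down the failure mode: a `Float64` input against the default `Float32` parameters now raises a `MethodError` rather than silently mixing precisions. If user code hits that error, the simplest fix is to make the input's eltype match the model's (a hedged sketch, not part of this commit):

```julia
using Flux

m = RNN(3, 5)                  # Float32 parameters by default
x = rand(Float64, 3, 1)

# m(x)                         # MethodError: input eltype doesn't match the state's
y = m(Float32.(x))             # convert the input so the eltypes agree
```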
