Skip to content

Commit 2a0ed9b

Browse files
committed
Add doctests in recurrent.jl
1 parent 8b189d0 commit 2a0ed9b

File tree

2 files changed

+117
-19
lines changed

2 files changed

+117
-19
lines changed

docs/src/models/layers.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Much like the core layers above, but can be used to process sequence data (as we
4242
RNN
4343
LSTM
4444
GRU
45+
GRUv3
4546
Flux.Recur
4647
Flux.reset!
4748
```

src/layers/recurrent.jl

Lines changed: 116 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,28 +63,97 @@ in the background. `cell` should be a model of the form:
6363
6464
For example, here's a recurrent network that keeps a running total of its inputs:
6565
66-
```julia
67-
accum(h, x) = (h + x, x)
68-
rnn = Flux.Recur(accum, 0)
69-
rnn(2) # 2
70-
rnn(3) # 3
71-
rnn.state # 5
72-
rnn.(1:10) # apply to a sequence
73-
rnn.state # 60
66+
# Examples
67+
```jldoctest
68+
julia> accum(h, x) = (h + x, x)
69+
accum (generic function with 1 method)
70+
71+
julia> rnn = Flux.Recur(accum, 0)
72+
Recur(accum)
73+
74+
julia> rnn(2)
75+
2
76+
77+
julia> rnn(3)
78+
3
79+
80+
julia> rnn.state
81+
5
82+
83+
julia> rnn.(1:10) # apply to a sequence
84+
10-element Vector{Int64}:
85+
1
86+
2
87+
3
88+
4
89+
5
90+
6
91+
7
92+
8
93+
9
94+
10
95+
96+
julia> rnn.state
97+
60
7498
```
7599
76100
Folding over a 3d Array of dimensions `(features, batch, time)` is also supported:
77101
78-
```julia
79-
accum(h, x) = (h .+ x, x)
80-
rnn = Flux.Recur(accum, zeros(Int, 1, 1))
81-
rnn([2]) # 2
82-
rnn([3]) # 3
83-
rnn.state # 5
84-
rnn(reshape(1:10, 1, 1, :)) # apply to a sequence of (features, batch, time)
85-
rnn.state # 60
86-
```
102+
```jldoctest
103+
julia> accum(h, x) = (h .+ x, x)
104+
accum (generic function with 1 method)
105+
106+
julia> rnn = Flux.Recur(accum, zeros(Int, 1, 1))
107+
Recur(accum)
108+
109+
julia> rnn([2])
110+
1-element Vector{Int64}:
111+
2
112+
113+
julia> rnn([3])
114+
1-element Vector{Int64}:
115+
3
116+
117+
julia> rnn.state
118+
1×1 Matrix{Int64}:
119+
5
120+
121+
julia> rnn(reshape(1:10, 1, 1, :)) # apply to a sequence of (features, batch, time)
122+
1×1×10 Array{Int64, 3}:
123+
[:, :, 1] =
124+
1
125+
126+
[:, :, 2] =
127+
2
128+
129+
[:, :, 3] =
130+
3
131+
132+
[:, :, 4] =
133+
4
87134
135+
[:, :, 5] =
136+
5
137+
138+
[:, :, 6] =
139+
6
140+
141+
[:, :, 7] =
142+
7
143+
144+
[:, :, 8] =
145+
8
146+
147+
[:, :, 9] =
148+
9
149+
150+
[:, :, 10] =
151+
10
152+
153+
julia> rnn.state
154+
1×1 Matrix{Int64}:
155+
60
156+
```
88157
"""
89158
mutable struct Recur{T,S}
90159
cell::T
@@ -107,8 +176,36 @@ Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
107176
Reset the hidden state of a recurrent layer back to its original value.
108177
109178
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
110-
```julia
111-
rnn.state = hidden(rnn.cell)
179+
180+
rnn.state = hidden(rnn.cell)
181+
182+
# Examples
183+
```jldoctest
184+
julia> r = RNN(3 => 5);
185+
186+
julia> r.state
187+
5×1 Matrix{Float32}:
188+
0.0
189+
0.0
190+
0.0
191+
0.0
192+
0.0
193+
194+
julia> r(rand(Float32, 3)); r.state
195+
5×1 Matrix{Float32}:
196+
-0.32719195
197+
-0.45280662
198+
-0.50386846
199+
-0.14782222
200+
0.23584609
201+
202+
julia> Flux.reset!(r)
203+
5×1 Matrix{Float32}:
204+
0.0
205+
0.0
206+
0.0
207+
0.0
208+
0.0
112209
```
113210
"""
114211
reset!(m::Recur) = (m.state = m.cell.state0)

0 commit comments

Comments
 (0)