@@ -63,28 +63,97 @@ in the background. `cell` should be a model of the form:

For example, here's a recurrent network that keeps a running total of its inputs:

- ```julia
- accum(h, x) = (h + x, x)
- rnn = Flux.Recur(accum, 0)
- rnn(2) # 2
- rnn(3) # 3
- rnn.state # 5
- rnn.(1:10) # apply to a sequence
- rnn.state # 60
+ # Examples
+ ```jldoctest
+ julia> accum(h, x) = (h + x, x)
+ accum (generic function with 1 method)
+
+ julia> rnn = Flux.Recur(accum, 0)
+ Recur(accum)
+
+ julia> rnn(2)
+ 2
+
+ julia> rnn(3)
+ 3
+
+ julia> rnn.state
+ 5
+
+ julia> rnn.(1:10) # apply to a sequence
+ 10-element Vector{Int64}:
+ 1
+ 2
+ 3
+ 4
+ 5
+ 6
+ 7
+ 8
+ 9
+ 10
+
+ julia> rnn.state
+ 60
```
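For orientation while reading this hunk: the wrapper relies only on the `h, y = cell(h, x)` contract described above. Below is a minimal sketch of that mechanism (illustrative only; `MyRecur` is a hypothetical name, not Flux's actual implementation):

```julia
# Sketch of a Recur-style wrapper: feed the stored state into the cell,
# keep the returned state for the next call, and hand back the output.
mutable struct MyRecur{T,S}
    cell::T
    state::S
end

function (m::MyRecur)(x)
    m.state, y = m.cell(m.state, x)  # cell returns (new_state, output)
    return y
end

accum(h, x) = (h + x, x)  # same toy cell as in the example above
r = MyRecur(accum, 0)
r(2); r(3)
r.state                   # 5, mirroring `rnn.state` in the doctest
```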

Folding over a 3d Array of dimensions `(features, batch, time)` is also supported:

- ```julia
- accum(h, x) = (h .+ x, x)
- rnn = Flux.Recur(accum, zeros(Int, 1, 1))
- rnn([2]) # 2
- rnn([3]) # 3
- rnn.state # 5
- rnn(reshape(1:10, 1, 1, :)) # apply to a sequence of (features, batch, time)
- rnn.state # 60
- ```
+ ```jldoctest
+ julia> accum(h, x) = (h .+ x, x)
+ accum (generic function with 1 method)
+
+ julia> rnn = Flux.Recur(accum, zeros(Int, 1, 1))
+ Recur(accum)
+
+ julia> rnn([2])
+ 1-element Vector{Int64}:
+ 2
+
+ julia> rnn([3])
+ 1-element Vector{Int64}:
+ 3
+
+ julia> rnn.state
+ 1×1 Matrix{Int64}:
+ 5
+
+ julia> rnn(reshape(1:10, 1, 1, :)) # apply to a sequence of (features, batch, time)
+ 1×1×10 Array{Int64, 3}:
+ [:, :, 1] =
+ 1
+
+ [:, :, 2] =
+ 2
+
+ [:, :, 3] =
+ 3
+
+ [:, :, 4] =
+ 4
+
+ [:, :, 5] =
+ 5
+
+ [:, :, 6] =
+ 6
+
+ [:, :, 7] =
+ 7
+
+ [:, :, 8] =
+ 8
+
+ [:, :, 9] =
+ 9
+
+ [:, :, 10] =
+ 10
+
+ julia> rnn.state
+ 1×1 Matrix{Int64}:
+ 60
+ ```
"""
mutable struct Recur{T,S}
  cell::T
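The hunk above also documents folding over a 3d `(features, batch, time)` array. As a rough sketch of that idea (not the method the package actually installs), a hypothetical helper could apply a stateful layer to one time slice after another and restack the outputs:

```julia
# Illustrative only: fold a stateful layer `m` over the time dimension of an
# array laid out as (features, batch, time). `apply_seq` is a hypothetical name.
function apply_seq(m, x::AbstractArray{<:Any,3})
    ys = [m(view(x, :, :, t)) for t in axes(x, 3)]  # one (features, batch) slice per step
    return cat(ys...; dims = 3)                     # restack outputs along time
end
```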
@@ -107,8 +176,36 @@ Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
Reset the hidden state of a recurrent layer back to its original value.

Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
- ```julia
- rnn.state = hidden(rnn.cell)
+
+     rnn.state = hidden(rnn.cell)
+
+ # Examples
+ ```jldoctest
+ julia> r = RNN(3 => 5);
+
+ julia> r.state
+ 5×1 Matrix{Float32}:
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+
+ julia> r(rand(Float32, 3)); r.state
+ 5×1 Matrix{Float32}:
+ -0.32719195
+ -0.45280662
+ -0.50386846
+ -0.14782222
+ 0.23584609
+
+ julia> Flux.reset!(r)
+ 5×1 Matrix{Float32}:
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+ 0.0
```
"""
reset!(m::Recur) = (m.state = m.cell.state0)
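A typical use of `reset!` is clearing the hidden state between independent sequences, so one sequence's state does not leak into the next. A hedged usage sketch (the model and data below are illustrative, not part of this diff):

```julia
using Flux

# Any stateful recurrent model works the same way; this one is made up for the example.
m = Chain(RNN(2 => 3), Dense(3 => 1))
sequences = [[rand(Float32, 2) for _ in 1:5] for _ in 1:4]  # 4 sequences, 5 steps each

for xs in sequences
    Flux.reset!(m)            # start each sequence from the cell's state0
    ys = [m(x) for x in xs]   # step through the sequence in order
end
```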