@@ -73,137 +73,122 @@ end

# Ref FluxML/Flux.jl#1209 1D input
@testset "BPTT-1D" begin
-    seq = [rand(Float32, 2) for i = 1:3]
-    for r ∈ [RNN]
-        rnn = r(2 => 3)
-        Flux.reset!(rnn)
-        grads_seq = gradient(Flux.params(rnn)) do
-            sum([rnn(s) for s in seq][3])
-        end
-        Flux.reset!(rnn)
-        bptt = gradient(
-            Wh -> sum(
-                tanh.(
-                    rnn.cell.Wi * seq[3] +
-                    Wh *
-                    tanh.(
-                        rnn.cell.Wi * seq[2] +
-                        Wh *
-                        tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.cell.state0 + rnn.cell.b) +
-                        rnn.cell.b
-                    ) +
-                    rnn.cell.b
-                ),
-            ),
-            rnn.cell.Wh,
-        )
-        @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
+  seq = [rand(Float32, 2) for i = 1:3]
+  for r ∈ [RNN,]
+    rnn = r(2 => 3)
+    Flux.reset!(rnn)
+    grads_seq = gradient(Flux.params(rnn)) do
+      sum([rnn(s) for s in seq][3])
    end
+    Flux.reset!(rnn);
+    bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
+                              tanh.(rnn.cell.Wi * seq[2] + Wh *
+                                    tanh.(rnn.cell.Wi * seq[1] +
+                                          Wh * rnn.cell.state0
+                                          + rnn.cell.b)
+                                    + rnn.cell.b)
+                              + rnn.cell.b)),
+                    rnn.cell.Wh)
+    @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
+  end
end

# Ref FluxML/Flux.jl#1209 2D input
@testset "BPTT-2D" begin
-    seq = [rand(Float32, (2, 1)) for i = 1:3]
-    for r ∈ [RNN]
-        rnn = r(2 => 3)
-        Flux.reset!(rnn)
-        grads_seq = gradient(Flux.params(rnn)) do
-            sum([rnn(s) for s in seq][3])
-        end
-        Flux.reset!(rnn)
-        bptt = gradient(
-            Wh -> sum(
-                tanh.(
-                    rnn.cell.Wi * seq[3] +
-                    Wh *
-                    tanh.(
-                        rnn.cell.Wi * seq[2] +
-                        Wh *
-                        tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.cell.state0 + rnn.cell.b) +
-                        rnn.cell.b
-                    ) +
-                    rnn.cell.b
-                ),
-            ),
-            rnn.cell.Wh,
-        )
-        @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
-    end
-end
-
-@testset "BPTT-3D" begin
-    seq = rand(Float32, (2, 1, 3))
-    rnn = RNN(2 => 3)
+  seq = [rand(Float32, (2, 1)) for i = 1:3]
+  for r ∈ [RNN,]
+    rnn = r(2 => 3)
    Flux.reset!(rnn)
    grads_seq = gradient(Flux.params(rnn)) do
-        sum(rnn(seq)[:, :, 3])
-    end
-    Flux.reset!(rnn)
-    bptt = gradient(rnn.cell.Wh) do Wh
-        # calculate state 1
-        s1 = tanh.(rnn.cell.Wi * seq[:, :, 1] + Wh * rnn.cell.state0 + rnn.cell.b)
-        # calculate state 2
-        s2 = tanh.(rnn.cell.Wi * seq[:, :, 2] + Wh * s1 + rnn.cell.b)
-        # calculate state 3
-        s3 = tanh.(rnn.cell.Wi * seq[:, :, 3] + Wh * s2 + rnn.cell.b)
-        sum(s3) # loss is sum of state 3
+      sum([rnn(s) for s in seq][3])
    end
+    Flux.reset!(rnn);
+    bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
+                              tanh.(rnn.cell.Wi * seq[2] + Wh *
+                                    tanh.(rnn.cell.Wi * seq[1] +
+                                          Wh * rnn.cell.state0
+                                          + rnn.cell.b)
+                                    + rnn.cell.b)
+                              + rnn.cell.b)),
+                    rnn.cell.Wh)
    @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
+  end
+end
+
+@testset "BPTT-3D" begin
+  seq = rand(Float32, (2, 1, 3))
+  rnn = RNN(2 => 3)
+  Flux.reset!(rnn)
+  grads_seq = gradient(Flux.params(rnn)) do
+    sum(rnn(seq)[:, :, 3])
+  end
+  Flux.reset!(rnn);
+  bptt = gradient(rnn.cell.Wh) do Wh
+    # calculate state 1
+    s1 = tanh.(rnn.cell.Wi * seq[:, :, 1] +
+               Wh * rnn.cell.state0 +
+               rnn.cell.b)
+    # calculate state 2
+    s2 = tanh.(rnn.cell.Wi * seq[:, :, 2] +
+               Wh * s1 +
+               rnn.cell.b)
+    # calculate state 3
+    s3 = tanh.(rnn.cell.Wi * seq[:, :, 3] +
+               Wh * s2 +
+               rnn.cell.b)
+    sum(s3) # loss is sum of state 3
+  end
+  @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
end

@testset "RNN-shapes" begin
-    @testset for R in [RNN, GRU, LSTM, GRUv3]
-        m1 = R(3 => 5)
-        m2 = R(3 => 5)
-        m3 = R(3, 5) # leave one to test the silently deprecated "," not "=>" notation
-        x1 = rand(Float32, 3)
-        x2 = rand(Float32, 3, 1)
-        x3 = rand(Float32, 3, 1, 2)
-        Flux.reset!(m1)
-        Flux.reset!(m2)
-        Flux.reset!(m3)
-        @test size(m1(x1)) == (5,)
-        @test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape
-        @test size(m2(x2)) == (5, 1)
-        @test size(m2(x2)) == (5, 1)
-        @test size(m3(x3)) == (5, 1, 2)
-        @test size(m3(x3)) == (5, 1, 2)
-    end
+  @testset for R in [RNN, GRU, LSTM, GRUv3]
+    m1 = R(3 => 5)
+    m2 = R(3 => 5)
+    m3 = R(3, 5) # leave one to test the silently deprecated "," not "=>" notation
+    x1 = rand(Float32, 3)
+    x2 = rand(Float32, 3, 1)
+    x3 = rand(Float32, 3, 1, 2)
+    Flux.reset!(m1)
+    Flux.reset!(m2)
+    Flux.reset!(m3)
+    @test size(m1(x1)) == (5,)
+    @test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape
+    @test size(m2(x2)) == (5, 1)
+    @test size(m2(x2)) == (5, 1)
+    @test size(m3(x3)) == (5, 1, 2)
+    @test size(m3(x3)) == (5, 1, 2)
+  end
end

@testset "multigate" begin
-    x = rand(6, 5)
-    res, (dx,) = Flux.withgradient(x) do x
-        x1, _, x3 = Flux.multigate(x, 2, Val(3))
-        sum(x1) + sum(x3 .* 2)
-    end
-    @test res == sum(x[1:2, :]) + 2sum(x[5:6, :])
-    @test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)]
+  x = rand(6, 5)
+  res, (dx,) = Flux.withgradient(x) do x
+    x1, _, x3 = Flux.multigate(x, 2, Val(3))
+    sum(x1) + sum(x3 .* 2)
+  end
+  @test res == sum(x[1:2, :]) + 2sum(x[5:6, :])
+  @test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)]
end

@testset "eachlastdim" begin
-    x = rand(3, 3, 1, 2, 4)
-    @test length(Flux.eachlastdim(x)) == size(x, ndims(x))
-    @test collect(@inferred(Flux.eachlastdim(x))) == collect(eachslice(x; dims = ndims(x)))
-    slicedim = (size(x)[1:end-1]..., 1)
-    res, (dx,) = Flux.withgradient(x) do x
-        x1, _, x3, _ = Flux.eachlastdim(x)
-        sum(x1) + sum(x3 .* 3)
-    end
-    @test res ≈ sum(selectdim(x, ndims(x), 1)) + 3sum(selectdim(x, ndims(x), 3))
-    @test dx ≈ cat(
-        fill(1, slicedim),
-        fill(0, slicedim),
-        fill(3, slicedim),
-        fill(0, slicedim);
-        dims = ndims(x),
-    )
+  x = rand(3, 3, 1, 2, 4)
+  @test length(Flux.eachlastdim(x)) == size(x, ndims(x))
+  @test collect(@inferred(Flux.eachlastdim(x))) == collect(eachslice(x; dims=ndims(x)))
+  slicedim = (size(x)[1:end-1]..., 1)
+  res, (dx,) = Flux.withgradient(x) do x
+    x1, _, x3, _ = Flux.eachlastdim(x)
+    sum(x1) + sum(x3 .* 3)
+  end
+  @test res ≈ sum(selectdim(x, ndims(x), 1)) + 3sum(selectdim(x, ndims(x), 3))
+  @test dx ≈ cat(fill(1, slicedim), fill(0, slicedim),
+                 fill(3, slicedim), fill(0, slicedim); dims=ndims(x))
end

@testset "∇eachlastdim" begin
  x = rand(3, 3, 1, 2, 4)
  x_size = size(x)
-  y = collect(eachslice(x; dims = ndims(x)))
+  y = collect(eachslice(x; dims=ndims(x)))
  @test @inferred(Flux.∇eachlastdim(y, x)) == x
  ZeroTangent = Flux.Zygote.ZeroTangent
  NoTangent = Flux.Zygote.NoTangent
@@ -212,78 +197,61 @@ end
  x2 = rand(Float64, x_size[1:end-1])
  x3 = rand(Float64, x_size[1:end-1])
  mixed_vector = [ZeroTangent(), x2, x3, ZeroTangent()]
-  @test @inferred(Flux.∇eachlastdim(mixed_vector, x)) ≈
-        cat(zeros(x_size[1:end-1]), x2, x3, zeros(x_size[1:end-1]); dims = ndims(x))
+  @test @inferred(Flux.∇eachlastdim(mixed_vector, x)) ≈ cat(zeros(x_size[1:end-1]),
+                                                            x2,
+                                                            x3,
+                                                            zeros(x_size[1:end-1]); dims=ndims(x))
end

@testset "Different Internal Matrix Types" begin
-    R = Flux.Recur(
-        Flux.RNNCell(tanh, rand(5, 3), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1)),
-    )
-    # don't want to pull in SparseArrays just for this test, but there aren't any
-    # non-square structured matrix types in LinearAlgebra. so we will use a different
-    # eltype matrix, which would fail before when `W_i` and `W_h` were required to be the
-    # same type.
-    L = Flux.Recur(
-        Flux.LSTMCell(
-            rand(5 * 4, 3),
-            rand(1:20, 5 * 4, 5),
-            rand(5 * 4),
-            (rand(5, 1), rand(5, 1)),
-        ),
-    )
-    G = Flux.Recur(
-        Flux.GRUCell(rand(5 * 3, 3), rand(1:20, 5 * 3, 5), rand(5 * 3), rand(5, 1)),
-    )
-    G3 = Flux.Recur(
-        Flux.GRUv3Cell(
-            rand(5 * 3, 3),
-            rand(1:20, 5 * 2, 5),
-            rand(5 * 3),
-            Tridiagonal(rand(5, 5)),
-            rand(5, 1),
-        ),
-    )
+  R = Flux.Recur(Flux.RNNCell(tanh, rand(5, 3), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1)))
+  # don't want to pull in SparseArrays just for this test, but there aren't any
+  # non-square structured matrix types in LinearAlgebra. so we will use a different
+  # eltype matrix, which would fail before when `W_i` and `W_h` were required to be the
+  # same type.
+  L = Flux.Recur(Flux.LSTMCell(rand(5*4, 3), rand(1:20, 5*4, 5), rand(5*4), (rand(5, 1), rand(5, 1))))
+  G = Flux.Recur(Flux.GRUCell(rand(5*3, 3), rand(1:20, 5*3, 5), rand(5*3), rand(5, 1)))
+  G3 = Flux.Recur(Flux.GRUv3Cell(rand(5*3, 3), rand(1:20, 5*2, 5), rand(5*3), Tridiagonal(rand(5, 5)), rand(5, 1)))

-    for m in [R, L, G, G3]
+  for m in [R, L, G, G3]

-        x1 = rand(3)
-        x2 = rand(3, 1)
-        x3 = rand(3, 1, 2)
-        Flux.reset!(m)
-        @test size(m(x1)) == (5,)
-        Flux.reset!(m)
-        @test size(m(x1)) == (5,) # repeat in case of effect from change in state shape
-        @test size(m(x2)) == (5, 1)
-        Flux.reset!(m)
-        @test size(m(x2)) == (5, 1)
-        Flux.reset!(m)
-        @test size(m(x3)) == (5, 1, 2)
-        Flux.reset!(m)
-        @test size(m(x3)) == (5, 1, 2)
-    end
+    x1 = rand(3)
+    x2 = rand(3, 1)
+    x3 = rand(3, 1, 2)
+    Flux.reset!(m)
+    @test size(m(x1)) == (5,)
+    Flux.reset!(m)
+    @test size(m(x1)) == (5,) # repeat in case of effect from change in state shape
+    @test size(m(x2)) == (5, 1)
+    Flux.reset!(m)
+    @test size(m(x2)) == (5, 1)
+    Flux.reset!(m)
+    @test size(m(x3)) == (5, 1, 2)
+    Flux.reset!(m)
+    @test size(m(x3)) == (5, 1, 2)
+  end
end

@testset "type matching" begin
-    x = rand(Float64, 2, 4)
-    m1 = RNN(2 => 3)
-    @test m1(x) isa Matrix{Float32} # uses _match_eltype, may print a warning
-    @test m1.state isa Matrix{Float32}
-    @test (@inferred m1(x); true)
-    @test Flux.outputsize(m1, size(x)) == size(m1(x))
+  x = rand(Float64, 2, 4)
+  m1 = RNN(2 => 3)
+  @test m1(x) isa Matrix{Float32} # uses _match_eltype, may print a warning
+  @test m1.state isa Matrix{Float32}
+  @test (@inferred m1(x); true)
+  @test Flux.outputsize(m1, size(x)) == size(m1(x))

-    m2 = LSTM(2 => 3)
-    @test m2(x) isa Matrix{Float32}
-    @test (@inferred m2(x); true)
-    @test Flux.outputsize(m2, size(x)) == size(m2(x))
+  m2 = LSTM(2 => 3)
+  @test m2(x) isa Matrix{Float32}
+  @test (@inferred m2(x); true)
+  @test Flux.outputsize(m2, size(x)) == size(m2(x))

-    m3 = GRU(2 => 3)
-    @test m3(x) isa Matrix{Float32}
-    @test (@inferred m3(x); true)
-    @test Flux.outputsize(m3, size(x)) == size(m3(x))
+  m3 = GRU(2 => 3)
+  @test m3(x) isa Matrix{Float32}
+  @test (@inferred m3(x); true)
+  @test Flux.outputsize(m3, size(x)) == size(m3(x))

-    m4 = GRUv3(2 => 3)
-    @test m4(x) isa Matrix{Float32}
-    @test (@inferred m4(x); true)
-    @test Flux.outputsize(m4, size(x)) == size(m4(x))
-end
+  m4 = GRUv3(2 => 3)
+  @test m4(x) isa Matrix{Float32}
+  @test (@inferred m4(x); true)
+  @test Flux.outputsize(m4, size(x)) == size(m4(x))
+end