
Commit cba35ff

new RNN tests - format revert
1 parent 15c285e commit cba35ff

File tree

1 file changed (+140, -172 lines)

test/layers/recurrent.jl

Lines changed: 140 additions & 172 deletions
@@ -73,137 +73,122 @@ end
 
 # Ref FluxML/Flux.jl#1209 1D input
 @testset "BPTT-1D" begin
-    seq = [rand(Float32, 2) for i = 1:3]
-    for r ∈ [RNN]
-        rnn = r(2 => 3)
-        Flux.reset!(rnn)
-        grads_seq = gradient(Flux.params(rnn)) do
-            sum([rnn(s) for s in seq][3])
-        end
-        Flux.reset!(rnn)
-        bptt = gradient(
-            Wh -> sum(
-                tanh.(
-                    rnn.cell.Wi * seq[3] +
-                    Wh *
-                    tanh.(
-                        rnn.cell.Wi * seq[2] +
-                        Wh *
-                        tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.cell.state0 + rnn.cell.b) +
-                        rnn.cell.b
-                    ) +
-                    rnn.cell.b
-                ),
-            ),
-            rnn.cell.Wh,
-        )
-        @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
+  seq = [rand(Float32, 2) for i = 1:3]
+  for r ∈ [RNN,]
+    rnn = r(2 => 3)
+    Flux.reset!(rnn)
+    grads_seq = gradient(Flux.params(rnn)) do
+      sum([rnn(s) for s in seq][3])
     end
+    Flux.reset!(rnn);
+    bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
+                        tanh.(rnn.cell.Wi * seq[2] + Wh *
+                              tanh.(rnn.cell.Wi * seq[1] +
+                                    Wh * rnn.cell.state0
+                                    + rnn.cell.b)
+                              + rnn.cell.b)
+                        + rnn.cell.b)),
+                    rnn.cell.Wh)
+    @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
+  end
 end
 
 # Ref FluxML/Flux.jl#1209 2D input
 @testset "BPTT-2D" begin
-    seq = [rand(Float32, (2, 1)) for i = 1:3]
-    for r ∈ [RNN]
-        rnn = r(2 => 3)
-        Flux.reset!(rnn)
-        grads_seq = gradient(Flux.params(rnn)) do
-            sum([rnn(s) for s in seq][3])
-        end
-        Flux.reset!(rnn)
-        bptt = gradient(
-            Wh -> sum(
-                tanh.(
-                    rnn.cell.Wi * seq[3] +
-                    Wh *
-                    tanh.(
-                        rnn.cell.Wi * seq[2] +
-                        Wh *
-                        tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.cell.state0 + rnn.cell.b) +
-                        rnn.cell.b
-                    ) +
-                    rnn.cell.b
-                ),
-            ),
-            rnn.cell.Wh,
-        )
-        @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
-    end
-end
-
-@testset "BPTT-3D" begin
-    seq = rand(Float32, (2, 1, 3))
-    rnn = RNN(2 => 3)
+  seq = [rand(Float32, (2, 1)) for i = 1:3]
+  for r ∈ [RNN,]
+    rnn = r(2 => 3)
     Flux.reset!(rnn)
     grads_seq = gradient(Flux.params(rnn)) do
-        sum(rnn(seq)[:, :, 3])
-    end
-    Flux.reset!(rnn)
-    bptt = gradient(rnn.cell.Wh) do Wh
-        # calculate state 1
-        s1 = tanh.(rnn.cell.Wi * seq[:, :, 1] + Wh * rnn.cell.state0 + rnn.cell.b)
-        #calculate state 2
-        s2 = tanh.(rnn.cell.Wi * seq[:, :, 2] + Wh * s1 + rnn.cell.b)
-        #calculate state 3
-        s3 = tanh.(rnn.cell.Wi * seq[:, :, 3] + Wh * s2 + rnn.cell.b)
-        sum(s3) # loss is sum of state 3
+      sum([rnn(s) for s in seq][3])
     end
+    Flux.reset!(rnn);
+    bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
+                        tanh.(rnn.cell.Wi * seq[2] + Wh *
+                              tanh.(rnn.cell.Wi * seq[1] +
+                                    Wh * rnn.cell.state0
+                                    + rnn.cell.b)
+                              + rnn.cell.b)
+                        + rnn.cell.b)),
+                    rnn.cell.Wh)
     @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
+  end
+end
+
+@testset "BPTT-3D" begin
+  seq = rand(Float32, (2, 1, 3))
+  rnn = RNN(2 => 3)
+  Flux.reset!(rnn)
+  grads_seq = gradient(Flux.params(rnn)) do
+    sum(rnn(seq)[:, :, 3])
+  end
+  Flux.reset!(rnn);
+  bptt = gradient(rnn.cell.Wh) do Wh
+    # calculate state 1
+    s1 = tanh.(rnn.cell.Wi * seq[:, :, 1] +
+               Wh * rnn.cell.state0 +
+               rnn.cell.b)
+    #calculate state 2
+    s2 = tanh.(rnn.cell.Wi * seq[:, :, 2] +
+               Wh * s1 +
+               rnn.cell.b)
+    #calculate state 3
+    s3 = tanh.(rnn.cell.Wi * seq[:, :, 3] +
+               Wh * s2 +
+               rnn.cell.b)
+    sum(s3) # loss is sum of state 3
+  end
+  @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
 end
 
 @testset "RNN-shapes" begin
-    @testset for R in [RNN, GRU, LSTM, GRUv3]
-        m1 = R(3 => 5)
-        m2 = R(3 => 5)
-        m3 = R(3, 5) # leave one to test the silently deprecated "," not "=>" notation
-        x1 = rand(Float32, 3)
-        x2 = rand(Float32, 3, 1)
-        x3 = rand(Float32, 3, 1, 2)
-        Flux.reset!(m1)
-        Flux.reset!(m2)
-        Flux.reset!(m3)
-        @test size(m1(x1)) == (5,)
-        @test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape
-        @test size(m2(x2)) == (5, 1)
-        @test size(m2(x2)) == (5, 1)
-        @test size(m3(x3)) == (5, 1, 2)
-        @test size(m3(x3)) == (5, 1, 2)
-    end
+  @testset for R in [RNN, GRU, LSTM, GRUv3]
+    m1 = R(3 => 5)
+    m2 = R(3 => 5)
+    m3 = R(3, 5) # leave one to test the silently deprecated "," not "=>" notation
+    x1 = rand(Float32, 3)
+    x2 = rand(Float32, 3, 1)
+    x3 = rand(Float32, 3, 1, 2)
+    Flux.reset!(m1)
+    Flux.reset!(m2)
+    Flux.reset!(m3)
+    @test size(m1(x1)) == (5,)
+    @test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape
+    @test size(m2(x2)) == (5, 1)
+    @test size(m2(x2)) == (5, 1)
+    @test size(m3(x3)) == (5, 1, 2)
+    @test size(m3(x3)) == (5, 1, 2)
+  end
 end
 
 @testset "multigate" begin
-    x = rand(6, 5)
-    res, (dx,) = Flux.withgradient(x) do x
-        x1, _, x3 = Flux.multigate(x, 2, Val(3))
-        sum(x1) + sum(x3 .* 2)
-    end
-    @test res == sum(x[1:2, :]) + 2sum(x[5:6, :])
-    @test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)]
+  x = rand(6, 5)
+  res, (dx,) = Flux.withgradient(x) do x
+    x1, _, x3 = Flux.multigate(x, 2, Val(3))
+    sum(x1) + sum(x3 .* 2)
+  end
+  @test res == sum(x[1:2, :]) + 2sum(x[5:6, :])
+  @test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)]
 end
 
 @testset "eachlastdim" begin
-    x = rand(3, 3, 1, 2, 4)
-    @test length(Flux.eachlastdim(x)) == size(x, ndims(x))
-    @test collect(@inferred(Flux.eachlastdim(x))) == collect(eachslice(x; dims = ndims(x)))
-    slicedim = (size(x)[1:end-1]..., 1)
-    res, (dx,) = Flux.withgradient(x) do x
-        x1, _, x3, _ = Flux.eachlastdim(x)
-        sum(x1) + sum(x3 .* 3)
-    end
-    @test res ≈ sum(selectdim(x, ndims(x), 1)) + 3sum(selectdim(x, ndims(x), 3))
-    @test dx ≈ cat(
-        fill(1, slicedim),
-        fill(0, slicedim),
-        fill(3, slicedim),
-        fill(0, slicedim);
-        dims = ndims(x),
-    )
+  x = rand(3, 3, 1, 2, 4)
+  @test length(Flux.eachlastdim(x)) == size(x, ndims(x))
+  @test collect(@inferred(Flux.eachlastdim(x))) == collect(eachslice(x; dims=ndims(x)))
+  slicedim = (size(x)[1:end-1]..., 1)
+  res, (dx,) = Flux.withgradient(x) do x
+    x1, _, x3, _ = Flux.eachlastdim(x)
+    sum(x1) + sum(x3 .* 3)
+  end
+  @test res ≈ sum(selectdim(x, ndims(x), 1)) + 3sum(selectdim(x, ndims(x), 3))
+  @test dx ≈ cat(fill(1, slicedim), fill(0, slicedim),
+                 fill(3, slicedim), fill(0, slicedim); dims=ndims(x))
 end
 
 @testset "∇eachlastdim" begin
     x = rand(3, 3, 1, 2, 4)
     x_size = size(x)
-    y = collect(eachslice(x; dims = ndims(x)))
+    y = collect(eachslice(x; dims=ndims(x)))
     @test @inferred(Flux.∇eachlastdim(y, x)) == x
     ZeroTangent = Flux.Zygote.ZeroTangent
     NoTangent = Flux.Zygote.NoTangent
@@ -212,78 +197,61 @@ end
     x2 = rand(Float64, x_size[1:end-1])
     x3 = rand(Float64, x_size[1:end-1])
     mixed_vector = [ZeroTangent(), x2, x3, ZeroTangent()]
-    @test @inferred(Flux.∇eachlastdim(mixed_vector, x)) ≈
-          cat(zeros(x_size[1:end-1]), x2, x3, zeros(x_size[1:end-1]); dims = ndims(x))
+    @test @inferred(Flux.∇eachlastdim(mixed_vector, x)) ≈ cat(zeros(x_size[1:end-1]),
+                                                               x2,
+                                                               x3,
+                                                               zeros(x_size[1:end-1]); dims=ndims(x))
 end
 
 @testset "Different Internal Matrix Types" begin
-    R = Flux.Recur(
-        Flux.RNNCell(tanh, rand(5, 3), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1)),
-    )
-    # don't want to pull in SparseArrays just for this test, but there aren't any
-    # non-square structured matrix types in LinearAlgebra. so we will use a different
-    # eltype matrix, which would fail before when `W_i` and `W_h` were required to be the
-    # same type.
-    L = Flux.Recur(
-        Flux.LSTMCell(
-            rand(5 * 4, 3),
-            rand(1:20, 5 * 4, 5),
-            rand(5 * 4),
-            (rand(5, 1), rand(5, 1)),
-        ),
-    )
-    G = Flux.Recur(
-        Flux.GRUCell(rand(5 * 3, 3), rand(1:20, 5 * 3, 5), rand(5 * 3), rand(5, 1)),
-    )
-    G3 = Flux.Recur(
-        Flux.GRUv3Cell(
-            rand(5 * 3, 3),
-            rand(1:20, 5 * 2, 5),
-            rand(5 * 3),
-            Tridiagonal(rand(5, 5)),
-            rand(5, 1),
-        ),
-    )
+  R = Flux.Recur(Flux.RNNCell(tanh, rand(5, 3), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1)))
+  # don't want to pull in SparseArrays just for this test, but there aren't any
+  # non-square structured matrix types in LinearAlgebra. so we will use a different
+  # eltype matrix, which would fail before when `W_i` and `W_h` were required to be the
+  # same type.
+  L = Flux.Recur(Flux.LSTMCell(rand(5*4, 3), rand(1:20, 5*4, 5), rand(5*4), (rand(5, 1), rand(5, 1))))
+  G = Flux.Recur(Flux.GRUCell(rand(5*3, 3), rand(1:20, 5*3, 5), rand(5*3), rand(5, 1)))
+  G3 = Flux.Recur(Flux.GRUv3Cell(rand(5*3, 3), rand(1:20, 5*2, 5), rand(5*3), Tridiagonal(rand(5, 5)), rand(5, 1)))
 
-    for m in [R, L, G, G3]
+  for m in [R, L, G, G3]
 
-        x1 = rand(3)
-        x2 = rand(3, 1)
-        x3 = rand(3, 1, 2)
-        Flux.reset!(m)
-        @test size(m(x1)) == (5,)
-        Flux.reset!(m)
-        @test size(m(x1)) == (5,) # repeat in case of effect from change in state shape
-        @test size(m(x2)) == (5, 1)
-        Flux.reset!(m)
-        @test size(m(x2)) == (5, 1)
-        Flux.reset!(m)
-        @test size(m(x3)) == (5, 1, 2)
-        Flux.reset!(m)
-        @test size(m(x3)) == (5, 1, 2)
-    end
+    x1 = rand(3)
+    x2 = rand(3, 1)
+    x3 = rand(3, 1, 2)
+    Flux.reset!(m)
+    @test size(m(x1)) == (5,)
+    Flux.reset!(m)
+    @test size(m(x1)) == (5,) # repeat in case of effect from change in state shape
+    @test size(m(x2)) == (5, 1)
+    Flux.reset!(m)
+    @test size(m(x2)) == (5, 1)
+    Flux.reset!(m)
+    @test size(m(x3)) == (5, 1, 2)
+    Flux.reset!(m)
+    @test size(m(x3)) == (5, 1, 2)
+  end
 end
 
 @testset "type matching" begin
-    x = rand(Float64, 2, 4)
-    m1 = RNN(2 => 3)
-    @test m1(x) isa Matrix{Float32} # uses _match_eltype, may print a warning
-    @test m1.state isa Matrix{Float32}
-    @test (@inferred m1(x); true)
-    @test Flux.outputsize(m1, size(x)) == size(m1(x))
+  x = rand(Float64, 2, 4)
+  m1 = RNN(2=>3)
+  @test m1(x) isa Matrix{Float32} # uses _match_eltype, may print a warning
+  @test m1.state isa Matrix{Float32}
+  @test (@inferred m1(x); true)
+  @test Flux.outputsize(m1, size(x)) == size(m1(x))
 
-    m2 = LSTM(2 => 3)
-    @test m2(x) isa Matrix{Float32}
-    @test (@inferred m2(x); true)
-    @test Flux.outputsize(m2, size(x)) == size(m2(x))
+  m2 = LSTM(2=>3)
+  @test m2(x) isa Matrix{Float32}
+  @test (@inferred m2(x); true)
+  @test Flux.outputsize(m2, size(x)) == size(m2(x))
 
-    m3 = GRU(2 => 3)
-    @test m3(x) isa Matrix{Float32}
-    @test (@inferred m3(x); true)
-    @test Flux.outputsize(m3, size(x)) == size(m3(x))
+  m3 = GRU(2=>3)
+  @test m3(x) isa Matrix{Float32}
+  @test (@inferred m3(x); true)
+  @test Flux.outputsize(m3, size(x)) == size(m3(x))
 
-    m4 = GRUv3(2 => 3)
-    @test m4(x) isa Matrix{Float32}
-    @test (@inferred m4(x); true)
-    @test Flux.outputsize(m4, size(x)) == size(m4(x))
-end
+  m4 = GRUv3(2=>3)
+  @test m4(x) isa Matrix{Float32}
+  @test (@inferred m4(x); true)
+  @test Flux.outputsize(m4, size(x)) == size(m4(x))
+end
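
Note: both BPTT testsets in this diff check Zygote's gradient through the
recurrent layer against a hand-unrolled chain of states
h_t = tanh.(Wi * x_t + Wh * h_(t-1) + b). A minimal, self-contained sketch
of that check (assuming the Flux ~0.13 Recur/implicit-params API these tests
use; not part of the commit):

    using Flux, Test

    seq = [rand(Float32, 2) for _ in 1:3]
    rnn = RNN(2 => 3)

    # Gradient through the layer itself: run the sequence, loss on the last output.
    Flux.reset!(rnn)
    grads = gradient(Flux.params(rnn)) do
        sum([rnn(s) for s in seq][3])
    end

    # The same quantity unrolled by hand, starting from the stored initial state.
    Flux.reset!(rnn)
    Wi, b, h0 = rnn.cell.Wi, rnn.cell.b, rnn.cell.state0
    bptt = gradient(rnn.cell.Wh) do Wh
        h1 = tanh.(Wi * seq[1] + Wh * h0 + b)
        h2 = tanh.(Wi * seq[2] + Wh * h1 + b)
        h3 = tanh.(Wi * seq[3] + Wh * h2 + b)
        sum(h3)
    end

    @test grads[rnn.cell.Wh] ≈ bptt[1]

The Flux.reset! between the two gradient calls is what makes the comparison
fair: it restores rnn.state to rnn.cell.state0, so both gradients start from
the same initial state.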

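Likewise, the "Different Internal Matrix Types" testset reduces to the point
its comment makes: a cell's input and recurrent weights no longer need to
share one concrete matrix type. A condensed sketch of that property (same
assumptions as above; not part of the commit):

    using Flux, LinearAlgebra, Test

    # Dense Wi (5x3) alongside a structured Tridiagonal Wh (5x5).
    R = Flux.Recur(Flux.RNNCell(tanh,
                                rand(5, 3),              # Wi
                                Tridiagonal(rand(5, 5)), # Wh
                                rand(5),                 # b
                                rand(5, 1)))             # state0

    Flux.reset!(R)
    @test size(R(rand(3))) == (5,)       # 1D input
    Flux.reset!(R)
    @test size(R(rand(3, 1))) == (5, 1)  # 2D input, batch of one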