Skip to content

Commit 130af41

Browse files
hotfix LSTM ouput (#2547)
1 parent 428be48 commit 130af41

File tree

4 files changed

+87
-134
lines changed

4 files changed

+87
-134
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ CUDA = "5"
4848
ChainRulesCore = "1.12"
4949
Compat = "4.10.0"
5050
Enzyme = "0.13"
51-
Functors = "0.5"
5251
EnzymeCore = "0.7.7, 0.8.4"
52+
Functors = "0.5"
5353
MLDataDevices = "1.4.2"
5454
MLUtils = "0.4"
5555
MPI = "0.20.19"

src/layers/recurrent.jl

Lines changed: 39 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
1+
out_from_state(state) = state
2+
out_from_state(state::Tuple) = state[1]
3+
4+
function scan(cell, x, state0)
5+
state = state0
6+
y = []
7+
for x_t in eachslice(x, dims = 2)
8+
state = cell(x_t, state)
9+
out = out_from_state(state)
10+
y = vcat(y, [out])
11+
end
12+
return stack(y, dims = 2)
13+
end
14+
115

2-
# Vanilla RNN
316

17+
# Vanilla RNN
418
@doc raw"""
519
RNNCell(in => out, σ = tanh; init_kernel = glorot_uniform,
620
init_recurrent_kernel = glorot_uniform, bias = true)
@@ -215,13 +229,7 @@ function (m::RNN)(x::AbstractArray, h)
215229
@assert ndims(x) == 2 || ndims(x) == 3
216230
# [x] = [in, L] or [in, L, B]
217231
# [h] = [out] or [out, B]
218-
y = []
219-
for x_t in eachslice(x, dims = 2)
220-
h = m.cell(x_t, h)
221-
# y = [y..., h]
222-
y = vcat(y, [h])
223-
end
224-
return stack(y, dims = 2)
232+
return scan(m.cell, x, h)
225233
end
226234

227235

@@ -297,22 +305,20 @@ function initialstates(lstm:: LSTMCell)
297305
end
298306

299307
function LSTMCell(
300-
(in, out)::Pair;
301-
init_kernel = glorot_uniform,
302-
init_recurrent_kernel = glorot_uniform,
303-
bias = true,
304-
)
308+
(in, out)::Pair;
309+
init_kernel = glorot_uniform,
310+
init_recurrent_kernel = glorot_uniform,
311+
bias = true,
312+
)
313+
305314
Wi = init_kernel(out * 4, in)
306315
Wh = init_recurrent_kernel(out * 4, out)
307316
b = create_bias(Wi, bias, out * 4)
308317
cell = LSTMCell(Wi, Wh, b)
309318
return cell
310319
end
311320

312-
function (lstm::LSTMCell)(x::AbstractVecOrMat)
313-
state, cstate = initialstates(lstm)
314-
return lstm(x, (state, cstate))
315-
end
321+
(lstm::LSTMCell)(x::AbstractVecOrMat) = lstm(x, initialstates(lstm))
316322

317323
function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
318324
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -368,15 +374,14 @@ The arguments of the forward pass are:
368374
They should be vectors of size `out` or matrices of size `out x batch_size`.
369375
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
370376
371-
Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t`
372-
in tensors of size `out x len` or `out x len x batch_size`.
377+
Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`.
373378
374379
# Examples
375380
376381
```julia
377382
struct Model
378383
lstm::LSTM
379-
h0::AbstractVector
384+
h0::AbstractVector # trainable initial hidden state
380385
c0::AbstractVector
381386
end
382387
@@ -387,7 +392,7 @@ Flux.@layer Model
387392
d_in, d_out, len, batch_size = 2, 3, 4, 5
388393
x = rand(Float32, (d_in, len, batch_size))
389394
model = Model(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
390-
h, c = model(x)
395+
h = model(x)
391396
size(h) # out x len x batch_size
392397
```
393398
"""
@@ -404,21 +409,11 @@ function LSTM((in, out)::Pair; cell_kwargs...)
404409
return LSTM(cell)
405410
end
406411

407-
function (lstm::LSTM)(x::AbstractArray)
408-
state, cstate = initialstates(lstm)
409-
return lstm(x, (state, cstate))
410-
end
412+
(lstm::LSTM)(x::AbstractArray) = lstm(x, initialstates(lstm))
411413

412-
function (m::LSTM)(x::AbstractArray, (h, c))
414+
function (m::LSTM)(x::AbstractArray, state0)
413415
@assert ndims(x) == 2 || ndims(x) == 3
414-
h′ = []
415-
c′ = []
416-
for x_t in eachslice(x, dims = 2)
417-
h, c = m.cell(x_t, (h, c))
418-
h′ = vcat(h′, [h])
419-
c′ = vcat(c′, [c])
420-
end
421-
return stack(h′, dims = 2), stack(c′, dims = 2)
416+
return scan(m.cell, x, state0)
422417
end
423418

424419
# GRU
@@ -485,11 +480,12 @@ end
485480
initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))
486481

487482
function GRUCell(
488-
(in, out)::Pair;
489-
init_kernel = glorot_uniform,
490-
init_recurrent_kernel = glorot_uniform,
491-
bias = true,
492-
)
483+
(in, out)::Pair;
484+
init_kernel = glorot_uniform,
485+
init_recurrent_kernel = glorot_uniform,
486+
bias = true,
487+
)
488+
493489
Wi = init_kernel(out * 3, in)
494490
Wh = init_recurrent_kernel(out * 3, out)
495491
b = create_bias(Wi, bias, size(Wi, 1))
@@ -581,20 +577,11 @@ function GRU((in, out)::Pair; cell_kwargs...)
581577
return GRU(cell)
582578
end
583579

584-
function (gru::GRU)(x::AbstractArray)
585-
state = initialstates(gru)
586-
return gru(x, state)
587-
end
580+
(gru::GRU)(x::AbstractArray) = gru(x, initialstates(gru))
588581

589582
function (m::GRU)(x::AbstractArray, h)
590583
@assert ndims(x) == 2 || ndims(x) == 3
591-
h′ = []
592-
# [x] = [in, L] or [in, L, B]
593-
for x_t in eachslice(x, dims = 2)
594-
h = m.cell(x_t, h)
595-
h′ = vcat(h′, [h])
596-
end
597-
return stack(h′, dims = 2)
584+
return scan(m.cell, x, h)
598585
end
599586

600587
# GRU v3
@@ -750,17 +737,9 @@ function GRUv3((in, out)::Pair; cell_kwargs...)
750737
return GRUv3(cell)
751738
end
752739

753-
function (gru::GRUv3)(x::AbstractArray)
754-
state = initialstates(gru)
755-
return gru(x, state)
756-
end
740+
(gru::GRUv3)(x::AbstractArray) = gru(x, initialstates(gru))
757741

758742
function (m::GRUv3)(x::AbstractArray, h)
759743
@assert ndims(x) == 2 || ndims(x) == 3
760-
h′ = []
761-
for x_t in eachslice(x, dims = 2)
762-
h = m.cell(x_t, h)
763-
h′ = vcat(h′, [h])
764-
end
765-
return stack(h′, dims = 2)
744+
return scan(m.cell, x, h)
766745
end

test/ext_common/recurrent_gpu_ad.jl

Lines changed: 40 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
1-
2-
@testset "RNNCell GPU AD" begin
3-
function loss(r, x, h)
4-
y = []
5-
for x_t in x
6-
h = r(x_t, h)
7-
y = vcat(y, [h])
8-
end
9-
# return mean(h)
10-
y = stack(y, dims=2) # [D, L] or [D, L, B]
11-
return mean(y)
1+
out_from_state(state::Tuple) = state[1]
2+
out_from_state(state) = state
3+
4+
function recurrent_cell_loss(cell, seq, state)
5+
out = []
6+
for xt in seq
7+
state = cell(xt, state)
8+
yt = out_from_state(state)
9+
out = vcat(out, [yt])
1210
end
11+
return mean(stack(out, dims = 2))
12+
end
1313

14+
@testset "RNNCell GPU AD" begin
1415
d_in, d_out, len, batch_size = 2, 3, 4, 5
1516
r = RNNCell(d_in => d_out)
1617
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
1718
h = zeros(Float32, d_out)
1819
# Single Step
19-
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :rnncell_single BROKEN_TESTS
20+
@test test_gradients(r, x[1], h; test_gpu=true,
21+
compare_finite_diff=false) broken = :rnncell_single BROKEN_TESTS
2022
# Multiple Steps
21-
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :rnncell_multiple BROKEN_TESTS
23+
@test test_gradients(r, x, h; test_gpu=true,
24+
compare_finite_diff=false,
25+
loss=recurrent_cell_loss) broken = :rnncell_multiple BROKEN_TESTS
2226
end
2327

2428
@testset "RNN GPU AD" begin
@@ -40,21 +44,6 @@ end
4044
end
4145

4246
@testset "LSTMCell" begin
43-
44-
function loss(r, x, hc)
45-
h, c = hc
46-
h′ = []
47-
c′ = []
48-
for x_t in x
49-
h, c = r(x_t, (h, c))
50-
h′ = vcat(h′, [h])
51-
c′ = [c′..., c]
52-
end
53-
hnew = stack(h′, dims=2)
54-
cnew = stack(c′, dims=2)
55-
return mean(hnew) + mean(cnew)
56-
end
57-
5847
d_in, d_out, len, batch_size = 2, 3, 4, 5
5948
cell = LSTMCell(d_in => d_out)
6049
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
@@ -64,7 +53,9 @@ end
6453
@test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
6554
loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single BROKEN_TESTS
6655
# Multiple Steps
67-
@test test_gradients(cell, x, (h, c); test_gpu=true, compare_finite_diff=false, loss) broken = :lstmcell_multiple BROKEN_TESTS
56+
@test test_gradients(cell, x, (h, c); test_gpu=true,
57+
compare_finite_diff = false,
58+
loss = recurrent_cell_loss) broken = :lstmcell_multiple BROKEN_TESTS
6859
end
6960

7061
@testset "LSTM" begin
@@ -81,30 +72,22 @@ end
8172
d_in, d_out, len, batch_size = 2, 3, 4, 5
8273
model = ModelLSTM(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
8374
x_nobatch = randn(Float32, d_in, len)
84-
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false,
85-
loss = (m, x) -> mean(m(x)[1])) broken = :lstm_nobatch BROKEN_TESTS
75+
@test test_gradients(model, x_nobatch; test_gpu=true,
76+
compare_finite_diff=false) broken = :lstm_nobatch BROKEN_TESTS
8677
x = randn(Float32, d_in, len, batch_size)
87-
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false,
88-
loss = (m, x) -> mean(m(x)[1])) broken = :lstm BROKEN_TESTS
78+
@test test_gradients(model, x; test_gpu=true,
79+
compare_finite_diff=false) broken = :lstm BROKEN_TESTS
8980
end
9081

9182
@testset "GRUCell" begin
92-
function loss(r, x, h)
93-
y = []
94-
for x_t in x
95-
h = r(x_t, h)
96-
y = vcat(y, [h])
97-
end
98-
y = stack(y, dims=2) # [D, L] or [D, L, B]
99-
return mean(y)
100-
end
101-
10283
d_in, d_out, len, batch_size = 2, 3, 4, 5
10384
r = GRUCell(d_in => d_out)
10485
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
10586
h = zeros(Float32, d_out)
10687
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single BROKEN_TESTS
107-
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :grucell_multiple BROKEN_TESTS
88+
@test test_gradients(r, x, h; test_gpu=true,
89+
compare_finite_diff = false,
90+
loss = recurrent_cell_loss) broken = :grucell_multiple BROKEN_TESTS
10891
end
10992

11093
@testset "GRU GPU AD" begin
@@ -120,28 +103,23 @@ end
120103
d_in, d_out, len, batch_size = 2, 3, 4, 5
121104
model = ModelGRU(GRU(d_in => d_out), zeros(Float32, d_out))
122105
x_nobatch = randn(Float32, d_in, len)
123-
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gru_nobatch BROKEN_TESTS
106+
@test test_gradients(model, x_nobatch; test_gpu=true,
107+
compare_finite_diff=false) broken = :gru_nobatch BROKEN_TESTS
124108
x = randn(Float32, d_in, len, batch_size)
125-
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gru BROKEN_TESTS
109+
@test test_gradients(model, x; test_gpu=true,
110+
compare_finite_diff=false) broken = :gru BROKEN_TESTS
126111
end
127112

128113
@testset "GRUv3Cell GPU AD" begin
129-
function loss(r, x, h)
130-
y = []
131-
for x_t in x
132-
h = r(x_t, h)
133-
y = vcat(y, [h])
134-
end
135-
y = stack(y, dims=2) # [D, L] or [D, L, B]
136-
return mean(y)
137-
end
138-
139114
d_in, d_out, len, batch_size = 2, 3, 4, 5
140115
r = GRUv3Cell(d_in => d_out)
141116
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
142117
h = zeros(Float32, d_out)
143-
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :gruv3cell_single BROKEN_TESTS
144-
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :gruv3cell_multiple BROKEN_TESTS
118+
@test test_gradients(r, x[1], h; test_gpu=true,
119+
compare_finite_diff=false) broken = :gruv3cell_single BROKEN_TESTS
120+
@test test_gradients(r, x, h; test_gpu=true,
121+
compare_finite_diff=false,
122+
loss = recurrent_cell_loss) broken = :gruv3cell_multiple BROKEN_TESTS
145123
end
146124

147125
@testset "GRUv3 GPU AD" begin
@@ -157,7 +135,9 @@ end
157135
d_in, d_out, len, batch_size = 2, 3, 4, 5
158136
model = ModelGRUv3(GRUv3(d_in => d_out), zeros(Float32, d_out))
159137
x_nobatch = randn(Float32, d_in, len)
160-
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gruv3_nobatch BROKEN_TESTS
138+
@test test_gradients(model, x_nobatch; test_gpu=true,
139+
compare_finite_diff=false) broken = :gruv3_nobatch BROKEN_TESTS
161140
x = randn(Float32, d_in, len, batch_size)
162-
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gruv3 BROKEN_TESTS
141+
@test test_gradients(model, x; test_gpu=true,
142+
compare_finite_diff=false) broken = :gruv3 BROKEN_TESTS
163143
end

test/layers/recurrent.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -156,37 +156,31 @@ end
156156
model = ModelLSTM(LSTM(2 => 4), zeros(Float32, 4), zeros(Float32, 4))
157157

158158
x = rand(Float32, 2, 3, 1)
159-
h, c = model(x)
159+
h = model(x)
160160
@test h isa Array{Float32, 3}
161161
@test size(h) == (4, 3, 1)
162-
@test c isa Array{Float32, 3}
163-
@test size(c) == (4, 3, 1)
164-
test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))
162+
test_gradients(model, x)
165163

166164
x = rand(Float32, 2, 3)
167-
h, c = model(x)
165+
h = model(x)
168166
@test h isa Array{Float32, 2}
169167
@test size(h) == (4, 3)
170-
@test c isa Array{Float32, 2}
171-
@test size(c) == (4, 3)
172168
test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))
173169

170+
# test default initial states
174171
lstm = model.lstm
175-
h, c = lstm(x)
172+
h = lstm(x)
176173
@test h isa Array{Float32, 2}
177174
@test size(h) == (4, 3)
178-
@test c isa Array{Float32, 2}
179-
@test size(c) == (4, 3)
180-
175+
181176
# initial states are zero
182177
h0, c0 = Flux.initialstates(lstm)
183178
@test h0 ≈ zeros(Float32, 4)
184179
@test c0 ≈ zeros(Float32, 4)
185180

186181
# no initial state same as zero initial state
187-
h1, c1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
182+
h1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
188183
@test h ≈ h1
189-
@test c ≈ c1
190184
end
191185

192186
@testset "GRUCell" begin

0 commit comments

Comments
 (0)