Skip to content

Commit 0038a60

Browse files
authored
Merge pull request #2215 from jeremiedb/jdb/rnn-debug
manual gradient checks for RNN - implicit and explicit gradients
2 parents 1ca15f3 + cba35ff commit 0038a60

File tree

1 file changed

+72
-1
lines changed

1 file changed

+72
-1
lines changed

test/layers/recurrent.jl

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,76 @@
11
using LinearAlgebra
22

3+
@testset "RNN gradients-implicit" begin
4+
layer = Flux.Recur(Flux.RNNCell(1, 1, identity))
5+
layer.cell.Wi .= 5.0
6+
layer.cell.Wh .= 4.0
7+
layer.cell.b .= 0.0f0
8+
layer.cell.state0 .= 7.0
9+
x = [[2.0f0], [3.0f0]]
10+
11+
# theoretical primal gradients
12+
primal =
13+
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
14+
x[2] .* layer.cell.Wi
15+
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
16+
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
17+
∇b = layer.cell.Wh .+ 1
18+
∇state0 = layer.cell.Wh .^ 2
19+
20+
Flux.reset!(layer)
21+
ps = Flux.params(layer)
22+
e, g = Flux.withgradient(ps) do
23+
out = [layer(xi) for xi in x]
24+
sum(out[2])
25+
end
26+
27+
@test primal[1] ≈ e
28+
@test ∇Wi ≈ g[ps[1]]
29+
@test ∇Wh ≈ g[ps[2]]
30+
@test ∇b ≈ g[ps[3]]
31+
@test ∇state0 ≈ g[ps[4]]
32+
33+
end
34+
35+
@testset "RNN gradients-explicit" begin
36+
layer = Flux.Recur(Flux.RNNCell(1, 1, identity))
37+
layer.cell.Wi .= 5.0f0
38+
layer.cell.Wh .= 4.0f0
39+
layer.cell.b .= 0.0f0
40+
layer.cell.state0 .= 7.0f0
41+
x = [[2.0f0], [3.0f0]]
42+
43+
# theoretical primal gradients
44+
primal =
45+
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
46+
x[2] .* layer.cell.Wi
47+
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
48+
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
49+
∇b = layer.cell.Wh .+ 1
50+
∇state0 = layer.cell.Wh .^ 2
51+
52+
Flux.reset!(layer)
53+
e, g = Flux.withgradient(layer) do m
54+
out = [m(xi) for xi in x]
55+
sum(out[2])
56+
end
57+
grads = g[1][:cell]
58+
59+
@test primal[1] ≈ e
60+
61+
if VERSION < v"1.7"
62+
@test ∇Wi ≈ grads[:Wi]
63+
@test ∇Wh ≈ grads[:Wh]
64+
@test ∇b ≈ grads[:b]
65+
@test_broken ∇state0 ≈ grads[:state0]
66+
else
67+
@test_broken ∇Wi ≈ grads[:Wi]
68+
@test_broken ∇Wh ≈ grads[:Wh]
69+
@test_broken ∇b ≈ grads[:b]
70+
@test_broken ∇state0 ≈ grads[:state0]
71+
end
72+
end
73+
374
# Ref FluxML/Flux.jl#1209 1D input
475
@testset "BPTT-1D" begin
576
seq = [rand(Float32, 2) for i = 1:3]
@@ -183,4 +254,4 @@ end
183254
@test m4(x) isa Matrix{Float32}
184255
@test (@inferred m4(x); true)
185256
@test Flux.outputsize(m4, size(x)) == size(m4(x))
186-
end
257+
end

0 commit comments

Comments
 (0)