@@ -1,5 +1,76 @@
using LinearAlgebra

+@testset "RNN gradients-implicit" begin
+    layer = Flux.Recur(Flux.RNNCell(1, 1, identity))
+    layer.cell.Wi .= 5.0f0
+    layer.cell.Wh .= 4.0f0
+    layer.cell.b .= 0.0f0
+    layer.cell.state0 .= 7.0f0
+    x = [[2.0f0], [3.0f0]]
+
+    # theoretical primal value and gradients
+    primal =
+        layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
+        x[2] .* layer.cell.Wi
+    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
+    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
+    ∇b = layer.cell.Wh .+ 1
+    ∇state0 = layer.cell.Wh .^ 2
+
+    Flux.reset!(layer)
+    ps = Flux.params(layer)
+    e, g = Flux.withgradient(ps) do
+        out = [layer(xi) for xi in x]
+        sum(out[2])
+    end
+
+    @test primal[1] ≈ e
+    @test ∇Wi ≈ g[ps[1]]
+    @test ∇Wh ≈ g[ps[2]]
+    @test ∇b ≈ g[ps[3]]
+    @test ∇state0 ≈ g[ps[4]]
+end
+
+@testset "RNN gradients-explicit" begin
+    layer = Flux.Recur(Flux.RNNCell(1, 1, identity))
+    layer.cell.Wi .= 5.0f0
+    layer.cell.Wh .= 4.0f0
+    layer.cell.b .= 0.0f0
+    layer.cell.state0 .= 7.0f0
+    x = [[2.0f0], [3.0f0]]
+
+    # theoretical primal value and gradients
+    primal =
+        layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
+        x[2] .* layer.cell.Wi
+    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
+    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
+    ∇b = layer.cell.Wh .+ 1
+    ∇state0 = layer.cell.Wh .^ 2
+
+    Flux.reset!(layer)
+    e, g = Flux.withgradient(layer) do m
+        out = [m(xi) for xi in x]
+        sum(out[2])
+    end
+    grads = g[1][:cell]
+
+    @test primal[1] ≈ e
+
+    if VERSION < v"1.7"
+        @test ∇Wi ≈ grads[:Wi]
+        @test ∇Wh ≈ grads[:Wh]
+        @test ∇b ≈ grads[:b]
+        @test_broken ∇state0 ≈ grads[:state0]
+    else
+        @test_broken ∇Wi ≈ grads[:Wi]
+        @test_broken ∇Wh ≈ grads[:Wh]
+        @test_broken ∇b ≈ grads[:b]
+        @test_broken ∇state0 ≈ grads[:state0]
+    end
+end
+
# Ref FluxML/Flux.jl#1209 1D input
@testset "BPTT-1D" begin
    seq = [rand(Float32, 2) for i = 1:3]
@@ -183,4 +254,4 @@
    @test m4(x) isa Matrix{Float32}
    @test (@inferred m4(x); true)
    @test Flux.outputsize(m4, size(x)) == size(m4(x))
-end
+end
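
For reference, the closed-form primal and ∇ values asserted in both testsets follow from unrolling the linear recurrence h_t = Wi*x_t + Wh*h_(t-1) + b (identity activation) over the two inputs and differentiating sum(out[2]) by hand. A minimal standalone sketch of that derivation in plain scalars (illustrative names, not part of the diff):

    # Scalar re-derivation of the expected values in the testsets above,
    # assuming the same constants the tests write into the cell.
    Wi, Wh, b, h0 = 5.0f0, 4.0f0, 0.0f0, 7.0f0
    x1, x2 = 2.0f0, 3.0f0

    h1 = Wi * x1 + Wh * h0 + b   # hidden state after the first input (38)
    h2 = Wi * x2 + Wh * h1 + b   # after the second input; the primal (167)

    # hand derivatives of h2 with respect to each parameter
    dWi     = x2 + Wh * x1       # 11, matches ∇Wi
    dWh     = h1 + Wh * h0       # 66; equals 2*Wh*h0 + Wi*x1 since b == 0, matching ∇Wh
    db      = 1 + Wh             # 5, matches ∇b
    dstate0 = Wh^2               # 16, matches ∇state0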