1
-
2
- @testset " RNNCell GPU AD" begin
3
- function loss (r, x, h)
4
- y = []
5
- for x_t in x
6
- h = r (x_t, h)
7
- y = vcat (y, [h])
8
- end
9
- # return mean(h)
10
- y = stack (y, dims= 2 ) # [D, L] or [D, L, B]
11
- return mean (y)
1
+ out_from_state (state:: Tuple ) = state[1 ]
2
+ out_from_state (state) = state
3
+
4
+ function recurrent_cell_loss (cell, seq, state)
5
+ out = []
6
+ for xt in seq
7
+ state = cell (xt, state)
8
+ yt = out_from_state (state)
9
+ out = vcat (out, [yt])
12
10
end
11
+ return mean (stack (out, dims = 2 ))
12
+ end
13
13
14
+ @testset " RNNCell GPU AD" begin
14
15
d_in, d_out, len, batch_size = 2 , 3 , 4 , 5
15
16
r = RNNCell (d_in => d_out)
16
17
x = [randn (Float32, d_in, batch_size) for _ in 1 : len]
17
18
h = zeros (Float32, d_out)
18
19
# Single Step
19
- @test test_gradients (r, x[1 ], h; test_gpu= true , compare_finite_diff= false ) broken = :rnncell_single ∈ BROKEN_TESTS
20
+ @test test_gradients (r, x[1 ], h; test_gpu= true ,
21
+ compare_finite_diff= false ) broken = :rnncell_single ∈ BROKEN_TESTS
20
22
# Multiple Steps
21
- @test test_gradients (r, x, h; test_gpu= true , compare_finite_diff= false , loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
23
+ @test test_gradients (r, x, h; test_gpu= true ,
24
+ compare_finite_diff= false ,
25
+ loss= recurrent_cell_loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
22
26
end
23
27
24
28
@testset " RNN GPU AD" begin
40
44
end
41
45
42
46
@testset " LSTMCell" begin
43
-
44
- function loss (r, x, hc)
45
- h, c = hc
46
- h′ = []
47
- c′ = []
48
- for x_t in x
49
- h, c = r (x_t, (h, c))
50
- h′ = vcat (h′, [h])
51
- c′ = [c′... , c]
52
- end
53
- hnew = stack (h′, dims= 2 )
54
- cnew = stack (c′, dims= 2 )
55
- return mean (hnew) + mean (cnew)
56
- end
57
-
58
47
d_in, d_out, len, batch_size = 2 , 3 , 4 , 5
59
48
cell = LSTMCell (d_in => d_out)
60
49
x = [randn (Float32, d_in, batch_size) for _ in 1 : len]
64
53
@test test_gradients (cell, x[1 ], (h, c); test_gpu= true , compare_finite_diff= false ,
65
54
loss = (m, x, (h, c)) -> mean (m (x, (h, c))[1 ])) broken = :lstmcell_single ∈ BROKEN_TESTS
66
55
# Multiple Steps
67
- @test test_gradients (cell, x, (h, c); test_gpu= true , compare_finite_diff= false , loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
56
+ @test test_gradients (cell, x, (h, c); test_gpu= true ,
57
+ compare_finite_diff = false ,
58
+ loss = recurrent_cell_loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
68
59
end
69
60
70
61
@testset " LSTM" begin
81
72
d_in, d_out, len, batch_size = 2 , 3 , 4 , 5
82
73
model = ModelLSTM (LSTM (d_in => d_out), zeros (Float32, d_out), zeros (Float32, d_out))
83
74
x_nobatch = randn (Float32, d_in, len)
84
- @test test_gradients (model, x_nobatch; test_gpu= true , compare_finite_diff = false ,
85
- loss = (m, x) -> mean ( m (x)[ 1 ]) ) broken = :lstm_nobatch ∈ BROKEN_TESTS
75
+ @test test_gradients (model, x_nobatch; test_gpu= true ,
76
+ compare_finite_diff = false ) broken = :lstm_nobatch ∈ BROKEN_TESTS
86
77
x = randn (Float32, d_in, len, batch_size)
87
- @test test_gradients (model, x; test_gpu= true , compare_finite_diff = false ,
88
- loss = (m, x) -> mean ( m (x)[ 1 ]) ) broken = :lstm ∈ BROKEN_TESTS
78
+ @test test_gradients (model, x; test_gpu= true ,
79
+ compare_finite_diff = false ) broken = :lstm ∈ BROKEN_TESTS
89
80
end
90
81
91
82
@testset " GRUCell" begin
92
- function loss (r, x, h)
93
- y = []
94
- for x_t in x
95
- h = r (x_t, h)
96
- y = vcat (y, [h])
97
- end
98
- y = stack (y, dims= 2 ) # [D, L] or [D, L, B]
99
- return mean (y)
100
- end
101
-
102
83
d_in, d_out, len, batch_size = 2 , 3 , 4 , 5
103
84
r = GRUCell (d_in => d_out)
104
85
x = [randn (Float32, d_in, batch_size) for _ in 1 : len]
105
86
h = zeros (Float32, d_out)
106
87
@test test_gradients (r, x[1 ], h; test_gpu= true , compare_finite_diff= false ) broken = :grucell_single ∈ BROKEN_TESTS
107
- @test test_gradients (r, x, h; test_gpu= true , compare_finite_diff= false , loss) broken = :grucell_multiple ∈ BROKEN_TESTS
88
+ @test test_gradients (r, x, h; test_gpu= true ,
89
+ compare_finite_diff = false ,
90
+ loss = recurrent_cell_loss) broken = :grucell_multiple ∈ BROKEN_TESTS
108
91
end
109
92
110
93
@testset " GRU GPU AD" begin
@@ -120,28 +103,23 @@ end
120
103
d_in, d_out, len, batch_size = 2 , 3 , 4 , 5
121
104
model = ModelGRU (GRU (d_in => d_out), zeros (Float32, d_out))
122
105
x_nobatch = randn (Float32, d_in, len)
123
- @test test_gradients (model, x_nobatch; test_gpu= true , compare_finite_diff= false ) broken = :gru_nobatch ∈ BROKEN_TESTS
106
+ @test test_gradients (model, x_nobatch; test_gpu= true ,
107
+ compare_finite_diff= false ) broken = :gru_nobatch ∈ BROKEN_TESTS
124
108
x = randn (Float32, d_in, len, batch_size)
125
- @test test_gradients (model, x; test_gpu= true , compare_finite_diff= false ) broken = :gru ∈ BROKEN_TESTS
109
+ @test test_gradients (model, x; test_gpu= true ,
110
+ compare_finite_diff= false ) broken = :gru ∈ BROKEN_TESTS
126
111
end
127
112
128
113
@testset " GRUv3Cell GPU AD" begin
129
- function loss (r, x, h)
130
- y = []
131
- for x_t in x
132
- h = r (x_t, h)
133
- y = vcat (y, [h])
134
- end
135
- y = stack (y, dims= 2 ) # [D, L] or [D, L, B]
136
- return mean (y)
137
- end
138
-
139
114
d_in, d_out, len, batch_size = 2 , 3 , 4 , 5
140
115
r = GRUv3Cell (d_in => d_out)
141
116
x = [randn (Float32, d_in, batch_size) for _ in 1 : len]
142
117
h = zeros (Float32, d_out)
143
- @test test_gradients (r, x[1 ], h; test_gpu= true , compare_finite_diff= false ) broken = :gruv3cell_single ∈ BROKEN_TESTS
144
- @test test_gradients (r, x, h; test_gpu= true , compare_finite_diff= false , loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
118
+ @test test_gradients (r, x[1 ], h; test_gpu= true ,
119
+ compare_finite_diff= false ) broken = :gruv3cell_single ∈ BROKEN_TESTS
120
+ @test test_gradients (r, x, h; test_gpu= true ,
121
+ compare_finite_diff= false ,
122
+ loss = recurrent_cell_loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
145
123
end
146
124
147
125
@testset " GRUv3 GPU AD" begin
157
135
d_in, d_out, len, batch_size = 2 , 3 , 4 , 5
158
136
model = ModelGRUv3 (GRUv3 (d_in => d_out), zeros (Float32, d_out))
159
137
x_nobatch = randn (Float32, d_in, len)
160
- @test test_gradients (model, x_nobatch; test_gpu= true , compare_finite_diff= false ) broken = :gruv3_nobatch ∈ BROKEN_TESTS
138
+ @test test_gradients (model, x_nobatch; test_gpu= true ,
139
+ compare_finite_diff= false ) broken = :gruv3_nobatch ∈ BROKEN_TESTS
161
140
x = randn (Float32, d_in, len, batch_size)
162
- @test test_gradients (model, x; test_gpu= true , compare_finite_diff= false ) broken = :gruv3 ∈ BROKEN_TESTS
141
+ @test test_gradients (model, x; test_gpu= true ,
142
+ compare_finite_diff= false ) broken = :gruv3 ∈ BROKEN_TESTS
163
143
end
0 commit comments