@@ -27,6 +27,15 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool
27
27
48 98 ;
28
28
58 108 ;
29
29
68 118. ]
30
+
31
+ # NaN tests for dilation forward pass
32
+
33
+ ys = []
34
+ for idx in 1 : 1000
35
+ push! (ys, conv (x, w; dilation= 2 ))
36
+ end
37
+ @test ! any ([any (isnan .(ys[idx])) for idx in 1 : 1000 ])
38
+
30
39
# for gradients, check only size
31
40
# correctness of gradients is cross-checked with CUDNN.jl
32
41
# (it's assumed convolution code won't change often)
@@ -39,6 +48,23 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool
39
48
@test size (y) == (3 , 2 , 1 , 1 )
40
49
@test size (∇conv_filter (y, x, w; stride= 2 , pad= 1 , dilation= 2 )) == size (w)
41
50
@test size (∇conv_data (y, x, w; stride= 2 , pad= 1 , dilation= 2 )) == size (x)
51
+
52
+ # NaN tests for dilation backward pass: filters
53
+ dy = randn (size (ys[1 ]))
54
+ dws = []
55
+ for idx in 1 : 1000
56
+ push! (dws, ∇conv_filter (dy, x, w; dilation= 2 ))
57
+ end
58
+
59
+ # NaN tests for dilation backward pass: input
60
+ dxs = []
61
+ for idx in 1 : 1000
62
+ push! (dxs, ∇conv_data (dy, x, w; dilation= 2 ))
63
+ end
64
+
65
+ @test ! any ([any (isnan .(dws[idx])) for idx in 1 : 1000 ])
66
+ @test ! any ([any (isnan .(dxs[idx])) for idx in 1 : 1000 ])
67
+
42
68
end
43
69
44
70
@@ -123,13 +149,37 @@ end
123
149
680 860.
124
150
]
125
151
152
+ # NaN tests for dilation forward pass
153
+
154
+ ys = []
155
+ for idx in 1 : 1000
156
+ push! (ys, conv (x, w; dilation= 2 ))
157
+ end
158
+ @test ! any ([any (isnan .(ys[idx])) for idx in 1 : 1000 ])
159
+
126
160
# for gradients, check only size
127
161
# correctness of gradients is cross-checked with CUDNN.jl
128
162
# (it's assumed convolution code won't change often)
129
163
130
164
@test size (∇conv_filter (reshape (rand (4 ,3 ,2 ), 4 , 3 , 2 , 1 , 1 ), x, w)) == size (w)
131
165
@test size (∇conv_data (reshape (rand (4 ,3 ,2 ), 4 , 3 , 2 , 1 , 1 ), x, w)) == size (x)
132
166
167
+ # NaN tests for dilation backward pass: filters
168
+ dy = randn (size (ys[1 ]))
169
+ dws = []
170
+ for idx in 1 : 1000
171
+ push! (dws, ∇conv_filter (dy, x, w; dilation= 2 ))
172
+ end
173
+
174
+ # NaN tests for dilation backward pass: input
175
+ dxs = []
176
+ for idx in 1 : 1000
177
+ push! (dxs, ∇conv_data (dy, x, w; dilation= 2 ))
178
+ end
179
+
180
+ @test ! any ([any (isnan .(dws[idx])) for idx in 1 : 1000 ])
181
+ @test ! any ([any (isnan .(dxs[idx])) for idx in 1 : 1000 ])
182
+
133
183
end
134
184
135
185
0 commit comments