Skip to content

Commit 44b5148

Browse files
committed
tests for NaN added
1 parent 9d71cd4 commit 44b5148

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

test/conv.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool
2727
48 98;
2828
58 108;
2929
68 118.]
30+
31+
# NaN tests for dilation forward pass
32+
33+
ys = []
34+
for idx in 1:1000
35+
push!(ys, conv(x, w; dilation=2))
36+
end
37+
@test !any([any(isnan.(ys[idx])) for idx in 1:1000])
38+
3039
# for gradients, check only size
3140
# correctness of gradients is cross-checked with CUDNN.jl
3241
# (it's assumed convolution code won't change often)
@@ -39,6 +48,23 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool
3948
@test size(y) == (3, 2, 1, 1)
4049
@test size(∇conv_filter(y, x, w; stride=2, pad=1, dilation=2)) == size(w)
4150
@test size(∇conv_data(y, x, w; stride=2, pad=1, dilation=2)) == size(x)
51+
52+
# NaN tests for dilation backward pass: filters
53+
dy = randn(size(ys[1]))
54+
dws = []
55+
for idx in 1:1000
56+
push!(dws, ∇conv_filter(dy, x, w; dilation=2))
57+
end
58+
59+
# NaN tests for dilation backward pass: input
60+
dxs = []
61+
for idx in 1:1000
62+
push!(dxs, ∇conv_data(dy, x, w; dilation=2))
63+
end
64+
65+
@test !any([any(isnan.(dws[idx])) for idx in 1:1000])
66+
@test !any([any(isnan.(dxs[idx])) for idx in 1:1000])
67+
4268
end
4369

4470

@@ -123,13 +149,37 @@ end
123149
680 860.
124150
]
125151

152+
# NaN tests for dilation forward pass
153+
154+
ys = []
155+
for idx in 1:1000
156+
push!(ys, conv(x, w; dilation=2))
157+
end
158+
@test !any([any(isnan.(ys[idx])) for idx in 1:1000])
159+
126160
# for gradients, check only size
127161
# correctness of gradients is cross-checked with CUDNN.jl
128162
# (it's assumed convolution code won't change often)
129163

130164
@test size(∇conv_filter(reshape(rand(4,3,2), 4, 3, 2, 1, 1), x, w)) == size(w)
131165
@test size(∇conv_data(reshape(rand(4,3,2), 4, 3, 2, 1, 1), x, w)) == size(x)
132166

167+
# NaN tests for dilation backward pass: filters
168+
dy = randn(size(ys[1]))
169+
dws = []
170+
for idx in 1:1000
171+
push!(dws, ∇conv_filter(dy, x, w; dilation=2))
172+
end
173+
174+
# NaN tests for dilation backward pass: input
175+
dxs = []
176+
for idx in 1:1000
177+
push!(dxs, ∇conv_data(dy, x, w; dilation=2))
178+
end
179+
180+
@test !any([any(isnan.(dws[idx])) for idx in 1:1000])
181+
@test !any([any(isnan.(dxs[idx])) for idx in 1:1000])
182+
133183
end
134184

135185

0 commit comments

Comments
 (0)