Skip to content

Commit 353012f

Browse files
author
Avik Pal
committed
Add tests for backward pass
1 parent 5fcd291 commit 353012f

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

test/conv.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwiseconv
1+
using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwiseconv, ∇depthwiseconv_filter, ∇depthwiseconv_data
22

33
@testset "conv2d" begin
44
x = reshape(Float64[1:20;], 5, 4, 1, 1)
@@ -52,6 +52,18 @@ end
5252
@test depthwiseconv(x, w, stride = 2)[:] == [37.0, 319.0]
5353

5454
@test depthwiseconv(x, w, pad = 1)[:] == [4.0, 11.0, 18.0, 9.0, 18.0, 37.0, 47.0, 21.0, 36.0, 67.0, 77.0, 33.0, 14.0, 23.0, 26.0, 9.0, 80.0, 158.0, 173.0, 84.0, 164.0, 319.0, 345.0, 165.0, 206.0, 397.0, 423.0, 201.0, 96.0, 182.0, 193.0, 90.0]
55+
56+
# the correctness of the gradients have been verified by calling
57+
# the corresponding counvolution gradients
58+
59+
@test size(∇depthwiseconv_filter(rand(2,2,2,1), x, w)) == size(w)
60+
@test size(∇depthwiseconv_data(rand(2,2,2,1), x, w)) == size(x)
61+
62+
# Test for the stride/pad for backward pass
63+
y = depthwiseconv(x,w,stride=2,pad=1)
64+
@test size(y) == (2,2,2,1)
65+
@test size(∇depthwiseconv_filter(rand(size(y)), x, w, stride=2, pad=1)) == size(w)
66+
@test size(∇depthwiseconv_data(rand(size(y)), x, w, stride=2, pad=1)) == size(x)
5567
end
5668

5769
@testset "maxpool2d" begin

0 commit comments

Comments
 (0)