Skip to content

Commit 3e5b77e

Browse files
author
Avik Pal
committed
Add additional tests for Gradients
1 parent 353012f commit 3e5b77e

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

test/conv.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,28 @@ end
5353

5454
@test depthwiseconv(x, w, pad = 1)[:] == [4.0, 11.0, 18.0, 9.0, 18.0, 37.0, 47.0, 21.0, 36.0, 67.0, 77.0, 33.0, 14.0, 23.0, 26.0, 9.0, 80.0, 158.0, 173.0, 84.0, 164.0, 319.0, 345.0, 165.0, 206.0, 397.0, 423.0, 201.0, 96.0, 182.0, 193.0, 90.0]
5555

56-
# the correctness of the gradients have been verified by calling
56+
# the correctness of the gradients are being verified by calling
5757
# the corresponding counvolution gradients
5858

59+
dy = reshape(Float64[1:8;], 2,2,2,1)
60+
local z = ∇depthwiseconv_data(dy,x,w)
61+
for i in 1:2
62+
X = copy(x[:,:,i:i,:]);
63+
W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3]));
64+
DY = copy(dy[:,:,i:i,:]);
65+
res = ∇conv_data(DY,X,W)
66+
@test squeeze(z[:,:,i:i,:], (3,4)) == squeeze(res, (3,4))
67+
end
68+
69+
z = ∇depthwiseconv_filter(dy, x, w)
70+
for i in 1:2
71+
X = copy(x[:,:,i:i,:]);
72+
W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3]))
73+
DY = copy(dy[:,:,i:i,:])
74+
res = ∇conv_filter(DY,X,W)
75+
@test squeeze(z[:,:,:,i:i], (3,4)) == squeeze(res, (3,4))
76+
end
77+
5978
@test size(∇depthwiseconv_filter(rand(2,2,2,1), x, w)) == size(w)
6079
@test size(∇depthwiseconv_data(rand(2,2,2,1), x, w)) == size(x)
6180

0 commit comments

Comments
 (0)