Skip to content

Commit 97800cc

Browse files
authored
Merge pull request #60 from maleadt/tb/1.0
Fixes for 0.7/1.0
2 parents ea23e3f + a3e6457 commit 97800cc

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

src/conv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,10 @@ depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,
101101
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
102102

103103
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
104-
∇depthwiseconv_data!(zeros(x), dy, x, w; pad = pad, stride = stride)
104+
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)
105105

106106
∇depthwiseconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
107-
∇depthwiseconv_filter!(zeros(w), dy, x, w; pad = pad, stride = stride)
107+
∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride)
108108

109109
∇depthwiseconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
110110
pad = 0, stride = 1) where T =

test/conv.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwisec
44
x = reshape(Float64[1:20;], 5, 4, 1, 1)
55
w = reshape(Float64[1:4;], 2, 2, 1, 1)
66

7-
@test squeeze(conv(x, w), dims = (3,4)) == [
7+
@test dropdims(conv(x, w), dims = (3,4)) == [
88
29 79 129;
99
39 89 139;
1010
49 99 149;
1111
59 109 159.]
1212

13-
@test squeeze(conv(x, w; stride=2), dims = (3,4)) == [
13+
@test dropdims(conv(x, w; stride=2), dims = (3,4)) == [
1414
29 129;
1515
49 149.]
1616

17-
@test squeeze(conv(x, w; pad=1), dims = (3,4)) == [
17+
@test dropdims(conv(x, w; pad=1), dims = (3,4)) == [
1818
1.0 9.0 29.0 49.0 48.0;
1919
4.0 29.0 79.0 129.0 115.0;
2020
7.0 39.0 89.0 139.0 122.0;
@@ -23,7 +23,7 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwisec
2323
10.0 40.0 70.0 100.0 80.0
2424
]
2525

26-
@test squeeze(conv(x, w; dilation=2), dims = (3,4)) == [
26+
@test dropdims(conv(x, w; dilation=2), dims = (3,4)) == [
2727
48 98;
2828
58 108;
2929
68 118.]
@@ -89,7 +89,7 @@ end
8989
W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3]));
9090
DY = copy(dy[:,:,2i-1:2i,:]);
9191
res = ∇conv_data(DY,X,W)
92-
@test squeeze(z[:,:,i:i,:], (3,4)) == squeeze(res, (3,4))
92+
@test dropdims(z[:,:,i:i,:], dims=(3,4)) == dropdims(res, dims=(3,4))
9393
end
9494

9595
z = ∇depthwiseconv_filter(dy, x, w)
@@ -98,7 +98,7 @@ end
9898
W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3]))
9999
DY = copy(dy[:,:,2i-1:2i,:])
100100
res = ∇conv_filter(DY,X,W)
101-
@test squeeze(z[:,:,:,i:i], (4)) == squeeze(res, (3))
101+
@test dropdims(z[:,:,:,i:i]; dims=(4)) == dropdims(res; dims=(3))
102102
end
103103

104104
@test size(∇depthwiseconv_filter(rand(2,2,4,1), x, w)) == size(w)
@@ -107,17 +107,17 @@ end
107107
# Test for the stride/pad for backward pass
108108
y = depthwiseconv(x,w,stride=2,pad=1)
109109
@test size(y) == (2,2,4,1)
110-
@test size(∇depthwiseconv_filter(rand(size(y)), x, w, stride=2, pad=1)) == size(w)
111-
@test size(∇depthwiseconv_data(rand(size(y)), x, w, stride=2, pad=1)) == size(x)
110+
@test size(∇depthwiseconv_filter(rand(Float64, size(y)), x, w, stride=2, pad=1)) == size(w)
111+
@test size(∇depthwiseconv_data(rand(Float64, size(y)), x, w, stride=2, pad=1)) == size(x)
112112
end
113113

114114
@testset "maxpool2d" begin
115115

116116
x = reshape(Float64[1:20;], 5, 4, 1, 1)
117117

118-
@test squeeze(maxpool(x, (2,2)), dims = (3,4)) == [7 17; 9 19]
119-
@test squeeze(maxpool(x, (2,2); stride=(2,2)), dims = (3,4)) == [7 17; 9 19]
120-
@test squeeze(maxpool(x, (2,2); pad=(1,1)), dims = (3,4)) == [
118+
@test dropdims(maxpool(x, (2,2)), dims = (3,4)) == [7 17; 9 19]
119+
@test dropdims(maxpool(x, (2,2); stride=(2,2)), dims = (3,4)) == [7 17; 9 19]
120+
@test dropdims(maxpool(x, (2,2); pad=(1,1)), dims = (3,4)) == [
121121
1.0 11.0 16.0;
122122
3.0 13.0 18.0;
123123
5.0 15.0 20.0;
@@ -149,9 +149,9 @@ end
149149
1078.0 1258.0 1438.0;
150150
1114.0 1294.0 1474.0;
151151
1150.0 1330.0 1510.0]
152-
@test squeeze(conv(x, w), dims = (4,5)) == res
152+
@test dropdims(conv(x, w), dims = (4,5)) == res
153153

154-
@test squeeze(conv(x, w; stride=2), dims = (3,4,5)) == [
154+
@test dropdims(conv(x, w; stride=2), dims = (3,4,5)) == [
155155
322.0 682.0;
156156
394.0 754.0]
157157

@@ -184,9 +184,9 @@ end
184184
478.0 1185.0 1315.0 1445.0 877.0;
185185
489.0 1211.0 1341.0 1471.0 892.0;
186186
270.0 660.0 730.0 800.0 480.0]
187-
@test squeeze(conv(x, w; pad=1), dims = (4,5)) == res
187+
@test dropdims(conv(x, w; pad=1), dims = (4,5)) == res
188188

189-
@test squeeze(conv(x, w; dilation=2), dims = (3,4,5)) == [
189+
@test dropdims(conv(x, w; dilation=2), dims = (3,4,5)) == [
190190
608 788;
191191
644 824;
192192
680 860.
@@ -230,8 +230,8 @@ end
230230

231231
x = reshape(Float64[1:60;], 5, 4, 3, 1, 1)
232232

233-
@test squeeze(maxpool(x, (2,2,2)), dims = (3,4,5)) == [27 37; 29 39.]
234-
@test squeeze(maxpool(x, (2,2,2); stride=(2,2,2)), dims = (3,4,5)) == [27 37; 29 39.]
233+
@test dropdims(maxpool(x, (2,2,2)), dims = (3,4,5)) == [27 37; 29 39.]
234+
@test dropdims(maxpool(x, (2,2,2); stride=(2,2,2)), dims = (3,4,5)) == [27 37; 29 39.]
235235
res = zeros(3,3,2)
236236
res[:, :, 1] = [
237237
1.0 11.0 16.0;
@@ -241,7 +241,7 @@ end
241241
41.0 51.0 56.0;
242242
43.0 53.0 58.0;
243243
45.0 55.0 60.0]
244-
@test squeeze(maxpool(x, (2,2,2), pad=(1,1,1)), dims = (4,5)) == res
244+
@test dropdims(maxpool(x, (2,2,2), pad=(1,1,1)), dims = (4,5)) == res
245245

246246
# for gradients, check only size
247247
# correctness of gradients is cross-checked with CUDNN.jl

0 commit comments

Comments
 (0)