Skip to content

Commit 1ed81ee

Browse files
authored
Merge pull request #45 from tejank10/dilation
Fixed NaN error in dilation
2 parents 513ff01 + 44b5148 commit 1ed81ee

File tree

2 files changed

+66
-17
lines changed

2 files changed

+66
-17
lines changed

src/impl/conv.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ end
4444

4545
function col2im_2d!{T}(col::AbstractArray{T,2}, img::AbstractArray{T,3}, width::Int, height::Int,
4646
channels::Int, kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int,
47-
stride_h::Int, dil_h::Int, dil_w::Int, mode::Int)
47+
stride_h::Int, dil_w::Int, dil_h::Int, mode::Int)
4848

4949
height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
5050
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
@@ -159,26 +159,25 @@ function dilation_dims(w, dilation = 1)
159159
end
160160
end
161161

162-
function im2col_dims(w,y,dilation=1)
162+
function im2col_dims(w,y)
163163
N = ndims(y)
164-
dil = dilation_dims(w, dilation)
165164
r,c = 1,1
166165
for i=1:N-2
167166
r *= size(y,i)
168-
c *= dil[i]
167+
c *= size(w,i)
169168
end
170-
c *= dil[N-1]
169+
c *= size(w,N-1)
171170
return (r, c)
172171
end
173172

174173
function conv2d!{T}(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
175174
padding=0, stride=1, dilation=1, mode=0, alpha=T(1))
176175
if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
177176
Wx,Hx,Cx,Nx = size(x)
178-
Ww,Hw,C1,C2 = dilation_dims(w, dilation)
177+
Ww,Hw,C1,C2 = size(w)
179178
if Cx!=C1; throw(DimensionMismatch()); end
180179
Wy,Hy,Cy,Ny = size(y)
181-
x2dims = im2col_dims(w,y,dilation)
180+
x2dims = im2col_dims(w,y)
182181
x2 = similar(x, x2dims)
183182
(p1,p2) = psize(padding,x)
184183
(s1,s2) = psize(stride,x)
@@ -197,11 +196,11 @@ function conv2d_grad_w!{T}(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abs
197196
padding=0, stride=1, dilation=1, mode=0, alpha=1)
198197
# dw = x'*dy
199198
Wx,Hx,Cx,Nx = size(x)
200-
Ww,Hw,C1,C2 = dilation_dims(w, dilation)
199+
Ww,Hw,C1,C2 = size(w)
201200
Wy,Hy,Cy,Ny = size(dy)
202201
# if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
203202
# @assert Cx==C1 && Cy==C2 && Ny==Nx
204-
x2dims = im2col_dims(w,dy,dilation)
203+
x2dims = im2col_dims(w,dy)
205204
x2 = similar(x, x2dims)
206205
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
207206
Y,M,N,K = Wy*Hy*Cy,Ww*Hw*Cx,Cy,Wy*Hy
@@ -222,11 +221,11 @@ function conv2d_grad_x!{T}(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abs
222221
padding=0, stride=1, dilation=1, mode=0, alpha=1)
223222
# dx = dy*w'
224223
Wx,Hx,Cx,Nx = size(x)
225-
Ww,Hw,C1,C2 = dilation_dims(w, dilation)
224+
Ww,Hw,C1,C2 = size(w)
226225
Wy,Hy,Cy,Ny = size(dy)
227226
# if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
228227
@assert Cx==C1 && Cy==C2 && Ny==Nx
229-
x2dims = im2col_dims(w,dy,dilation)
228+
x2dims = im2col_dims(w,dy)
230229
x2 = similar(x, x2dims)
231230
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
232231
Y,M,N,K = Wy*Hy*Cy,Wy*Hy,Ww*Hw*Cx,Cy
@@ -266,11 +265,11 @@ function conv3d!{T}(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArr
266265
padding=0, stride=1, dilation = 1, mode=0, alpha=T(1))
267266
if mode != 0 && mode != 1; throw(ArgumentError("conv3d only supports mode=0 or 1.")); end
268267
Wx,Hx,Dx,Cx,Nx = size(x)
269-
Ww,Hw,Dw,C1,C2 = dilation_dims(w, dilation)
268+
Ww,Hw,Dw,C1,C2 = size(w)
270269
if Cx!=C1; throw(DimensionMismatch()); end
271270
Wy,Hy,Dy,Cy,Ny = size(y)
272271
# @assert Cy==C2 && Ny==Nx
273-
x2dims = im2col_dims(w,y,dilation)
272+
x2dims = im2col_dims(w,y)
274273
x2 = similar(x, x2dims)
275274
(p1,p2,p3) = psize(padding,x)
276275
(s1,s2,s3) = psize(stride,x)
@@ -290,11 +289,11 @@ function conv3d_grad_w!{T}(dw::AbstractArray{T,5}, x::AbstractArray{T,5}, w::Abs
290289
padding=0, stride=1, dilation = 1, mode=0, alpha=1)
291290
# dw = x'*dy
292291
Wx,Hx,Dx,Cx,Nx = size(x)
293-
Ww,Hw,Dw,C1,C2 = dilation_dims(w, dilation)
292+
Ww,Hw,Dw,C1,C2 = size(w)
294293
Wy,Hy,Dy,Cy,Ny = size(dy)
295294
# if mode != 0 && mode != 1; throw(ArgumentError("conv3d only supports mode=0 or 1.")); end
296295
# @assert Cx==C1 && Cy==C2 && Ny==Nx
297-
x2dims = im2col_dims(w,dy,dilation)
296+
x2dims = im2col_dims(w,dy)
298297
x2 = similar(x, x2dims)
299298
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
300299
Y,M,N,K = Wy*Hy*Dy*Cy,Ww*Hw*Dw*Cx,Cy,Wy*Hy*Dy
@@ -315,11 +314,11 @@ function conv3d_grad_x!{T}(dx::AbstractArray{T,5}, x::AbstractArray{T,5}, w::Abs
315314
padding=0, stride=1, dilation = 1, mode=0, alpha=1)
316315
# dx = dy*w'
317316
Wx,Hx,Dx,Cx,Nx = size(x)
318-
Ww,Hw,Dw,C1,C2 = dilation_dims(w, dilation)
317+
Ww,Hw,Dw,C1,C2 = size(w)
319318
Wy,Hy,Dy,Cy,Ny = size(dy)
320319
# if mode != 0 && mode != 1; throw(ArgumentError("conv3d only supports mode=0 or 1.")); end
321320
@assert Cx==C1 && Cy==C2 && Ny==Nx
322-
x2dims = im2col_dims(w,dy,dilation)
321+
x2dims = im2col_dims(w,dy)
323322
x2 = similar(x, x2dims)
324323
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
325324
Y,M,N,K = Wy*Hy*Dy*Cy,Wy*Hy*Dy,Ww*Hw*Dw*Cx,Cy

test/conv.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool
2727
48 98;
2828
58 108;
2929
68 118.]
30+
31+
# NaN tests for dilation forward pass
32+
33+
ys = []
34+
for idx in 1:1000
35+
push!(ys, conv(x, w; dilation=2))
36+
end
37+
@test !any([any(isnan.(ys[idx])) for idx in 1:1000])
38+
3039
# for gradients, check only size
3140
# correctness of gradients is cross-checked with CUDNN.jl
3241
# (it's assumed convolution code won't change often)
@@ -39,6 +48,23 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool
3948
@test size(y) == (3, 2, 1, 1)
4049
@test size(∇conv_filter(y, x, w; stride=2, pad=1, dilation=2)) == size(w)
4150
@test size(∇conv_data(y, x, w; stride=2, pad=1, dilation=2)) == size(x)
51+
52+
# NaN tests for dilation backward pass: filters
53+
dy = randn(size(ys[1]))
54+
dws = []
55+
for idx in 1:1000
56+
push!(dws, ∇conv_filter(dy, x, w; dilation=2))
57+
end
58+
59+
# NaN tests for dilation backward pass: input
60+
dxs = []
61+
for idx in 1:1000
62+
push!(dxs, ∇conv_data(dy, x, w; dilation=2))
63+
end
64+
65+
@test !any([any(isnan.(dws[idx])) for idx in 1:1000])
66+
@test !any([any(isnan.(dxs[idx])) for idx in 1:1000])
67+
4268
end
4369

4470

@@ -123,13 +149,37 @@ end
123149
680 860.
124150
]
125151

152+
# NaN tests for dilation forward pass
153+
154+
ys = []
155+
for idx in 1:1000
156+
push!(ys, conv(x, w; dilation=2))
157+
end
158+
@test !any([any(isnan.(ys[idx])) for idx in 1:1000])
159+
126160
# for gradients, check only size
127161
# correctness of gradients is cross-checked with CUDNN.jl
128162
# (it's assumed convolution code won't change often)
129163

130164
@test size(∇conv_filter(reshape(rand(4,3,2), 4, 3, 2, 1, 1), x, w)) == size(w)
131165
@test size(∇conv_data(reshape(rand(4,3,2), 4, 3, 2, 1, 1), x, w)) == size(x)
132166

167+
# NaN tests for dilation backward pass: filters
168+
dy = randn(size(ys[1]))
169+
dws = []
170+
for idx in 1:1000
171+
push!(dws, ∇conv_filter(dy, x, w; dilation=2))
172+
end
173+
174+
# NaN tests for dilation backward pass: input
175+
dxs = []
176+
for idx in 1:1000
177+
push!(dxs, ∇conv_data(dy, x, w; dilation=2))
178+
end
179+
180+
@test !any([any(isnan.(dws[idx])) for idx in 1:1000])
181+
@test !any([any(isnan.(dxs[idx])) for idx in 1:1000])
182+
133183
end
134184

135185

0 commit comments

Comments
 (0)