44
44
45
45
function col2im_2d! {T} (col:: AbstractArray{T,2} , img:: AbstractArray{T,3} , width:: Int , height:: Int ,
46
46
channels:: Int , kernel_w:: Int , kernel_h:: Int , pad_w:: Int , pad_h:: Int , stride_w:: Int ,
47
- stride_h:: Int , dil_h :: Int , dil_w :: Int , mode:: Int )
47
+ stride_h:: Int , dil_w :: Int , dil_h :: Int , mode:: Int )
48
48
49
49
height_col = div (height + 2 pad_h - (kernel_h - 1 ) * dil_h - 1 , stride_h) + 1
50
50
width_col = div (width + 2 pad_w - (kernel_w - 1 ) * dil_w - 1 , stride_w) + 1
@@ -159,26 +159,25 @@ function dilation_dims(w, dilation = 1)
159
159
end
160
160
end
161
161
162
- function im2col_dims (w,y,dilation = 1 )
162
+ function im2col_dims (w,y)
163
163
N = ndims (y)
164
- dil = dilation_dims (w, dilation)
165
164
r,c = 1 ,1
166
165
for i= 1 : N- 2
167
166
r *= size (y,i)
168
- c *= dil[i]
167
+ c *= size (w,i)
169
168
end
170
- c *= dil[ N- 1 ]
169
+ c *= size (w, N- 1 )
171
170
return (r, c)
172
171
end
173
172
174
173
function conv2d! {T} (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
175
174
padding= 0 , stride= 1 , dilation= 1 , mode= 0 , alpha= T (1 ))
176
175
if mode != 0 && mode != 1 ; throw (ArgumentError (" conv2d only supports mode=0 or 1." )); end
177
176
Wx,Hx,Cx,Nx = size (x)
178
- Ww,Hw,C1,C2 = dilation_dims (w, dilation )
177
+ Ww,Hw,C1,C2 = size (w )
179
178
if Cx!= C1; throw (DimensionMismatch ()); end
180
179
Wy,Hy,Cy,Ny = size (y)
181
- x2dims = im2col_dims (w,y,dilation )
180
+ x2dims = im2col_dims (w,y)
182
181
x2 = similar (x, x2dims)
183
182
(p1,p2) = psize (padding,x)
184
183
(s1,s2) = psize (stride,x)
@@ -197,11 +196,11 @@ function conv2d_grad_w!{T}(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abs
197
196
padding= 0 , stride= 1 , dilation= 1 , mode= 0 , alpha= 1 )
198
197
# dw = x'*dy
199
198
Wx,Hx,Cx,Nx = size (x)
200
- Ww,Hw,C1,C2 = dilation_dims (w, dilation )
199
+ Ww,Hw,C1,C2 = size (w )
201
200
Wy,Hy,Cy,Ny = size (dy)
202
201
# if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
203
202
# @assert Cx==C1 && Cy==C2 && Ny==Nx
204
- x2dims = im2col_dims (w,dy,dilation )
203
+ x2dims = im2col_dims (w,dy)
205
204
x2 = similar (x, x2dims)
206
205
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
207
206
Y,M,N,K = Wy* Hy* Cy,Ww* Hw* Cx,Cy,Wy* Hy
@@ -222,11 +221,11 @@ function conv2d_grad_x!{T}(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abs
222
221
padding= 0 , stride= 1 , dilation= 1 , mode= 0 , alpha= 1 )
223
222
# dx = dy*w'
224
223
Wx,Hx,Cx,Nx = size (x)
225
- Ww,Hw,C1,C2 = dilation_dims (w, dilation )
224
+ Ww,Hw,C1,C2 = size (w )
226
225
Wy,Hy,Cy,Ny = size (dy)
227
226
# if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
228
227
@assert Cx== C1 && Cy== C2 && Ny== Nx
229
- x2dims = im2col_dims (w,dy,dilation )
228
+ x2dims = im2col_dims (w,dy)
230
229
x2 = similar (x, x2dims)
231
230
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
232
231
Y,M,N,K = Wy* Hy* Cy,Wy* Hy,Ww* Hw* Cx,Cy
@@ -266,11 +265,11 @@ function conv3d!{T}(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArr
266
265
padding= 0 , stride= 1 , dilation = 1 , mode= 0 , alpha= T (1 ))
267
266
if mode != 0 && mode != 1 ; throw (ArgumentError (" conv3d only supports mode=0 or 1." )); end
268
267
Wx,Hx,Dx,Cx,Nx = size (x)
269
- Ww,Hw,Dw,C1,C2 = dilation_dims (w, dilation )
268
+ Ww,Hw,Dw,C1,C2 = size (w )
270
269
if Cx!= C1; throw (DimensionMismatch ()); end
271
270
Wy,Hy,Dy,Cy,Ny = size (y)
272
271
# @assert Cy==C2 && Ny==Nx
273
- x2dims = im2col_dims (w,y,dilation )
272
+ x2dims = im2col_dims (w,y)
274
273
x2 = similar (x, x2dims)
275
274
(p1,p2,p3) = psize (padding,x)
276
275
(s1,s2,s3) = psize (stride,x)
@@ -290,11 +289,11 @@ function conv3d_grad_w!{T}(dw::AbstractArray{T,5}, x::AbstractArray{T,5}, w::Abs
290
289
padding= 0 , stride= 1 , dilation = 1 , mode= 0 , alpha= 1 )
291
290
# dw = x'*dy
292
291
Wx,Hx,Dx,Cx,Nx = size (x)
293
- Ww,Hw,Dw,C1,C2 = dilation_dims (w, dilation )
292
+ Ww,Hw,Dw,C1,C2 = size (w )
294
293
Wy,Hy,Dy,Cy,Ny = size (dy)
295
294
# if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
296
295
# @assert Cx==C1 && Cy==C2 && Ny==Nx
297
- x2dims = im2col_dims (w,dy,dilation )
296
+ x2dims = im2col_dims (w,dy)
298
297
x2 = similar (x, x2dims)
299
298
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
300
299
Y,M,N,K = Wy* Hy* Dy* Cy,Ww* Hw* Dw* Cx,Cy,Wy* Hy* Dy
@@ -315,11 +314,11 @@ function conv3d_grad_x!{T}(dx::AbstractArray{T,5}, x::AbstractArray{T,5}, w::Abs
315
314
padding= 0 , stride= 1 , dilation = 1 , mode= 0 , alpha= 1 )
316
315
# dx = dy*w'
317
316
Wx,Hx,Dx,Cx,Nx = size (x)
318
- Ww,Hw,Dw,C1,C2 = dilation_dims (w, dilation )
317
+ Ww,Hw,Dw,C1,C2 = size (w )
319
318
Wy,Hy,Dy,Cy,Ny = size (dy)
320
319
# if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
321
320
@assert Cx== C1 && Cy== C2 && Ny== Nx
322
- x2dims = im2col_dims (w,dy,dilation )
321
+ x2dims = im2col_dims (w,dy)
323
322
x2 = similar (x, x2dims)
324
323
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
325
324
Y,M,N,K = Wy* Hy* Dy* Cy,Wy* Hy* Dy,Ww* Hw* Dw* Cx,Cy
0 commit comments