@@ -170,6 +170,92 @@ function im2col_dims(w,y)
170
170
return (r, c)
171
171
end
172
172
173
+ function im2col_dims (w:: NTuple{4, Int} , y)
174
+ N = ndims (y)
175
+ r,c = 1 ,1
176
+ for i= 1 : N- 2
177
+ r *= size (y,i)
178
+ c *= w[i]
179
+ end
180
+ c *= w[N- 1 ]
181
+ return (r, c)
182
+ end
183
+
184
+ function depthwiseconv2d! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
185
+ padding = 0 , stride = 1 , mode = 1 , alpha = T (1 )) where T
186
+ Wx,Hx,Cx,Nx = size (x)
187
+ Ww,Hw,Cm,Cw = size (w) # Cm = Channel Multiplier
188
+ @assert Cx == Cw DimensionMismatch ()
189
+ Wy,Hy,Cy,Ny = size (y) # Cy = Cw * Cm
190
+ dims_w = (Ww,Hw,Cw,Cm* Cw)
191
+ x2dims = im2col_dims (dims_w,y)
192
+ x2 = similar (x, x2dims)
193
+ (p1,p2) = psize (padding,x)
194
+ (s1,s2) = psize (stride,x)
195
+ M,N,K,Y = Wy* Hy,Cm,Ww* Hw,Wy* Hy* Cm
196
+ yidx = 1
197
+ @inbounds for i in 1 : Nx
198
+ im2col2d! (dims_w, x, x2, i, p1, p2, s1, s2, mode)
199
+ @inbounds for j in 1 : Cx
200
+ gemm! (' N' ,' N' ,M,N,K,alpha,pointer (x2,(j- 1 )* M* K+ 1 ),pointer (w,(j- 1 )* K* N+ 1 ),T (0 ),pointer (y,yidx))
201
+ yidx += Y
202
+ end
203
+ end
204
+ return y
205
+ end
206
+
207
+ function depthwiseconv2d_grad_w! (dw:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} , dy:: AbstractArray{T,4} ;
208
+ padding= 0 , stride= 1 , mode= 0 , alpha= 1 ) where T
209
+ Wx,Hx,Cx,Nx = size (x)
210
+ Ww,Hw,Cm,Cw = size (w) # Cm = Channel Multiplier
211
+ @assert Cx == Cw DimensionMismatch ()
212
+ Wy,Hy,Cy,Ny = size (dy) # Cy = Cw * Cm
213
+ @assert Cy == Cw * Cm DimensionMismatch ()
214
+ dims_w = (Ww,Hw,Cw,Cm* Cw)
215
+ x2dims = im2col_dims (dims_w,dy)
216
+ x2 = similar (x, x2dims)
217
+ (p1,p2) = psize (padding,x)
218
+ (s1,s2) = psize (stride,x)
219
+ M,N,K,Y,W = Ww* Hw,Cm,Wy* Hy,Wy* Hy* Cm* Cx,Ww* Hw* Cm
220
+ alpha,beta = T (alpha),T (1 )
221
+ dyidx = 1
222
+ @inbounds for i in 1 : Nx
223
+ im2col2d! (dims_w, x, x2, i, p1, p2, s1, s2, mode)
224
+ dwidx = 1
225
+ @inbounds for j in 1 : Cx
226
+ gemm! (' T' ,' T' ,M,N,K,alpha,pointer (x2,(j- 1 )* M* K+ 1 ),pointer (dy,dyidx+ (j- 1 )* K* N),beta,pointer (dw,dwidx))
227
+ dwidx += W
228
+ end
229
+ dyidx += Y
230
+ end
231
+ return dw
232
+ end
233
+
234
+ function depthwiseconv2d_grad_x! (dx:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} , dy:: AbstractArray{T,4} ;
235
+ padding= 0 , stride= 1 , mode= 0 , alpha= 1 ) where T
236
+ Wx,Hx,Cx,Nx = size (x)
237
+ Ww,Hw,Cm,Cw = size (w) # Cm = Channel Multiplier
238
+ @assert Cx == Cw DimensionMismatch ()
239
+ Wy,Hy,Cy,Ny = size (dy) # Cy = Cw * Cm
240
+ @assert Cy == Cw * Cm DimensionMismatch ()
241
+ dims_w = (Ww,Hw,Cw,Cm* Cw)
242
+ x2dims = im2col_dims (dims_w,dy)
243
+ x2 = similar (x, x2dims)
244
+ M,N,K,Y,W = Wy* Hy,Ww* Hw,Cm,Wy* Hy* Cm* Cx,Ww* Hw* Cm
245
+ alpha,beta = T (alpha),T (0 )
246
+ (p1,p2) = psize (padding,x)
247
+ (s1,s2) = psize (stride,x)
248
+ dyidx = 1
249
+ @inbounds for i in 1 : Nx
250
+ @inbounds for j in 1 : Cx
251
+ gemm! (' N' ,' T' ,M,N,K,alpha,pointer (dy,dyidx+ (j- 1 )* K* M),pointer (w,(j- 1 )* K* N+ 1 ),beta,pointer (x2,(j- 1 )* M* N+ 1 ))
252
+ end
253
+ col2im2d! (dims_w,dx,x2,i,p1,p2,s1,s2,mode)
254
+ dyidx += Y
255
+ end
256
+ return dx
257
+ end
258
+
173
259
function conv2d! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
174
260
padding= 0 , stride= 1 , dilation= 1 , mode= 0 , alpha= T (1 )) where T
175
261
if mode != 0 && mode != 1 ; throw (ArgumentError (" conv2d only supports mode=0 or 1." )); end
@@ -242,6 +328,15 @@ function conv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abstra
242
328
return dx
243
329
end
244
330
331
+ function im2col2d! (w:: NTuple{4,Int} , x:: AbstractArray{T,4} , x2:: AbstractArray{T,2} ,
332
+ n:: Int , p1:: Int , p2:: Int , s1:: Int , s2:: Int , mode:: Int ) where T
333
+ Wx,Hx,Cx,Nx = size (x)
334
+ Ww,Hw,C1,C2 = w
335
+ xn = x[:, :, :, n]
336
+ im2col_2d! (xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1 ,1 ,mode)
337
+ return x2
338
+ end
339
+
245
340
function im2col2d! (w:: AbstractArray{T,4} , x:: AbstractArray{T,4} , x2:: AbstractArray{T,2} ,
246
341
n:: Int , p1:: Int , p2:: Int , s1:: Int , s2:: Int , d1:: Int , d2:: Int , mode:: Int ) where T
247
342
Wx,Hx,Cx,Nx = size (x)
@@ -251,6 +346,16 @@ function im2col2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArr
251
346
return x2
252
347
end
253
348
349
+ function col2im2d! (w:: NTuple{4,Int} , x:: AbstractArray{T,4} , x2:: AbstractArray{T,2} ,
350
+ n:: Int , p1:: Int , p2:: Int , s1:: Int , s2:: Int , mode:: Int ) where T
351
+ Wx,Hx,Cx,Nx = size (x)
352
+ Ww,Hw,C1,C2 = w
353
+ xn = x[:, :, :, n]
354
+ col2im_2d! (x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1 ,1 ,mode)
355
+ x[:, :, :, n] = xn
356
+ return x
357
+ end
358
+
254
359
function col2im2d! (w:: AbstractArray{T,4} , x:: AbstractArray{T,4} , x2:: AbstractArray{T,2} ,
255
360
n:: Int , p1:: Int , p2:: Int , s1:: Int , s2:: Int , d1:: Int , d2:: Int , mode:: Int ) where T
256
361
Wx,Hx,Cx,Nx = size (x)
0 commit comments