39
39
# N-D dispatch
40
40
41
41
function conv! (y:: AbstractArray{T,3} , x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
42
- pad = 0 , stride = 1 , dilation = 1 ) where T
42
+ pad = 0 , stride = 1 , dilation = 1 , mode = 0 ) where T
43
43
args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (y, x, w))
44
- conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ))
44
+ conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ), mode = mode )
45
45
return y
46
46
end
47
47
@@ -74,8 +74,8 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
74
74
conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
75
75
76
76
conv! (y:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
77
- pad = 0 , stride = 1 , dilation = 1 ) where T =
78
- conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
77
+ pad = 0 , stride = 1 , dilation = 1 , mode = 0 ) where T =
78
+ conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode = mode )
79
79
80
80
∇conv_filter! (dw:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
81
81
pad = 0 , stride = 1 , dilation = 1 ) where T =
@@ -91,14 +91,14 @@ function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
91
91
((x[1 ] + 2 * pad[1 ] - w[1 ])÷ stride[1 ] + 1 ,(x[2 ] + 2 * pad[2 ] - w[2 ])÷ stride[2 ] + 1 ,w[3 ]* w[4 ],x[4 ])
92
92
end
93
93
94
- function depthwiseconv (x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray
94
+ function depthwiseconv (x:: A , w:: A ; pad = 0 , stride = 1 , mode = 0 ) where A<: AbstractArray
95
95
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
96
- depthwiseconv! (similar (x, dcdims (size (x), size (w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
96
+ depthwiseconv! (similar (x, dcdims (size (x), size (w), pad_, stride_)), x, w, pad = pad_, stride = stride_, mode = mode )
97
97
end
98
98
99
99
depthwiseconv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
100
- pad = 0 , stride = 1 ) where T =
101
- depthwiseconv2d! (y, x, w, padding = pad, stride = stride)
100
+ pad = 0 , stride = 1 , mode = 0 ) where T =
101
+ depthwiseconv2d! (y, x, w, padding = pad, stride = stride, mode = mode )
102
102
103
103
∇depthwiseconv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
104
104
∇depthwiseconv_data! (zero (x), dy, x, w; pad = pad, stride = stride)
0 commit comments