@@ -30,6 +30,12 @@ function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractAr
30
30
x, w, pad = pad_, stride = stride_, dilation = dilation)
31
31
end
32
32
33
+ function crossconv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
34
+ pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
35
+ crossconv! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)),
36
+ x, w, pad = pad_, stride = stride_, dilation = dilation)
37
+ end
38
+
33
39
∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
34
40
∇conv_data! (zero (x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
35
41
39
45
# N-D dispatch
40
46
41
47
function conv! (y:: AbstractArray{T,3} , x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
42
- pad = 0 , stride = 1 , dilation = 1 ) where T
48
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T
43
49
args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (y, x, w))
44
- conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ))
50
+ conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ), flipkernel = flipkernel )
45
51
return y
46
52
end
47
53
54
+ function crossconv! (y:: AbstractArray , x:: AbstractArray , w:: AbstractArray ;
55
+ pad = 0 , stride = 1 , dilation = 1 )
56
+ conv! (y, x, w, pad= pad, stride= stride, dilation= dilation, flipkernel= 1 )
57
+ end
58
+
48
59
function ∇conv_filter! (dw:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
49
60
x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
50
61
pad = 0 , stride = 1 , dilation = 1 ) where T
@@ -62,8 +73,8 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
62
73
end
63
74
64
75
conv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
65
- pad = 0 , stride = 1 , dilation = 1 ) where T =
66
- conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
76
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T =
77
+ conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode = flipkernel )
67
78
68
79
∇conv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
69
80
pad = 0 , stride = 1 , dilation = 1 ) where T =
@@ -74,8 +85,8 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
74
85
conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
75
86
76
87
conv! (y:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
77
- pad = 0 , stride = 1 , dilation = 1 ) where T =
78
- conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
88
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T =
89
+ conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode = flipkernel )
79
90
80
91
∇conv_filter! (dw:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
81
92
pad = 0 , stride = 1 , dilation = 1 ) where T =
@@ -85,7 +96,7 @@ conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
85
96
pad = 0 , stride = 1 , dilation = 1 ) where T =
86
97
conv3d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
87
98
88
- # Depthwise Conv
99
+ # Depthwise Conv
89
100
90
101
function dcdims (x:: NTuple{4,Int} , w:: NTuple{4,Int} , pad, stride)
91
102
((x[1 ] + 2 * pad[1 ] - w[1 ])÷ stride[1 ] + 1 ,(x[2 ] + 2 * pad[2 ] - w[2 ])÷ stride[2 ] + 1 ,w[3 ]* w[4 ],x[4 ])
@@ -96,9 +107,18 @@ function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
96
107
depthwiseconv! (similar (x, dcdims (size (x), size (w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
97
108
end
98
109
110
+ function depthwisecrossconv (x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray
111
+ pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
112
+ depthwisecrossconv! (similar (x, dcdims (size (x), size (w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
113
+ end
114
+
99
115
depthwiseconv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
116
+ pad = 0 , stride = 1 , flipkernel= 0 ) where T =
117
+ depthwiseconv2d! (y, x, w, padding = pad, stride = stride, mode= flipkernel)
118
+
119
+ depthwisecrossconv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
100
120
pad = 0 , stride = 1 ) where T =
101
- depthwiseconv2d ! (y, x, w, padding = pad, stride = stride)
121
+ depthwiseconv ! (y, x, w, pad = pad, stride = stride, flipkernel = 1 )
102
122
103
123
∇depthwiseconv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
104
124
∇depthwiseconv_data! (zero (x), dy, x, w; pad = pad, stride = stride)
0 commit comments