@@ -24,10 +24,16 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
24
24
padtuple (x:: Tuple ,p:: Tuple ) = p
25
25
padtuple (x:: AbstractArray ,p) = padtuple (size (x),p)
26
26
27
- function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , mode = 0 ) where A<: AbstractArray
27
+ function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
28
28
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
29
29
conv! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)),
30
- x, w, pad = pad_, stride = stride_, dilation = dilation, mode = mode)
30
+ x, w, pad = pad_, stride = stride_, dilation = dilation)
31
+ end
32
+
33
+ function crossconv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
34
+ pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
35
+ crossconv! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)),
36
+ x, w, pad = pad_, stride = stride_, dilation = dilation)
31
37
end
32
38
33
39
∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
39
45
# N-D dispatch
40
46
41
47
function conv! (y:: AbstractArray{T,3} , x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
42
- pad = 0 , stride = 1 , dilation = 1 , mode = 0 ) where T
48
+ pad = 0 , stride = 1 , dilation = 1 ) where T
43
49
args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (y, x, w))
44
- conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ), mode = mode )
50
+ conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ))
45
51
return y
46
52
end
47
53
54
+ function crossconv! (y:: AbstractArray{T,3} , x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
55
+ pad = 0 , stride = 1 , dilation = 1 ) where T
56
+ args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (y, x, w))
57
+ crossconv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ))
58
+ return y
59
+ end
60
+
48
61
function ∇conv_filter! (dw:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
49
62
x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
50
63
pad = 0 , stride = 1 , dilation = 1 ) where T
@@ -62,8 +75,12 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
62
75
end
63
76
64
77
conv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
65
- pad = 0 , stride = 1 , dilation = 1 , mode= 0 ) where T =
66
- conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode= mode)
78
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
79
+ conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
80
+
81
+ crossconv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
82
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
83
+ conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode= 1 )
67
84
68
85
∇conv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
69
86
pad = 0 , stride = 1 , dilation = 1 ) where T =
@@ -74,8 +91,12 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
74
91
conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
75
92
76
93
conv! (y:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
77
- pad = 0 , stride = 1 , dilation = 1 , mode= 0 ) where T =
78
- conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode= mode)
94
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
95
+ conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
96
+
97
+ crossconv! (y:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
98
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
99
+ conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode= 1 )
79
100
80
101
∇conv_filter! (dw:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
81
102
pad = 0 , stride = 1 , dilation = 1 ) where T =
@@ -91,14 +112,23 @@ function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
91
112
((x[1 ] + 2 * pad[1 ] - w[1 ])÷ stride[1 ] + 1 ,(x[2 ] + 2 * pad[2 ] - w[2 ])÷ stride[2 ] + 1 ,w[3 ]* w[4 ],x[4 ])
92
113
end
93
114
94
- function depthwiseconv (x:: A , w:: A ; pad = 0 , stride = 1 , mode= 0 ) where A<: AbstractArray
115
+ function depthwiseconv (x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray
116
+ pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
117
+ depthwiseconv! (similar (x, dcdims (size (x), size (w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
118
+ end
119
+
120
+ function depthwisecrossconv (x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray
95
121
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
96
- depthwiseconv ! (similar (x, dcdims (size (x), size (w), pad_, stride_)), x, w, pad = pad_, stride = stride_, mode = mode )
122
+ depthwisecrossconv ! (similar (x, dcdims (size (x), size (w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
97
123
end
98
124
99
125
depthwiseconv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
100
- pad = 0 , stride = 1 , mode= 0 ) where T =
101
- depthwiseconv2d! (y, x, w, padding = pad, stride = stride, mode= mode)
126
+ pad = 0 , stride = 1 ) where T =
127
+ depthwiseconv2d! (y, x, w, padding = pad, stride = stride)
128
+
129
+ depthwisecrossconv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
130
+ pad = 0 , stride = 1 ) where T =
131
+ depthwiseconv2d! (y, x, w, padding = pad, stride = stride, mode= 1 )
102
132
103
133
∇depthwiseconv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
104
134
∇depthwiseconv_data! (zero (x), dy, x, w; pad = pad, stride = stride)
0 commit comments