39
39
∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
40
40
∇conv_data! (zero (x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
41
41
42
- ∇crossconv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
43
- ∇crossconv_data! (zero (x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
44
-
45
42
∇conv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
46
43
∇conv_filter! (zero (w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
47
44
48
- ∇crossconv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
49
- ∇crossconv_filter! (zero (w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
50
-
51
45
# N-D dispatch
52
46
53
47
function conv! (y:: AbstractArray{T,3} , x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
54
- pad = 0 , stride = 1 , dilation = 1 ) where T
48
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T
55
49
args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (y, x, w))
56
- conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ))
50
+ conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ), flipkernel = flipkernel )
57
51
return y
58
52
end
59
53
60
- function crossconv! (y:: AbstractArray{T,3} , x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
61
- pad = 0 , stride = 1 , dilation = 1 ) where T
62
- args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (y, x, w))
63
- crossconv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ))
64
- return y
54
+ function crossconv! (y:: AbstractArray , x:: AbstractArray , w:: AbstractArray ;
55
+ pad = 0 , stride = 1 , dilation = 1 )
56
+ conv! (y, x, w, pad= pad, stride= stride, dilation= dilation, flipkernel= 1 )
65
57
end
66
58
67
59
function ∇conv_filter! (dw:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
@@ -72,14 +64,6 @@ function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
72
64
return dw
73
65
end
74
66
75
- function ∇crossconv_filter! (dw:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
76
- x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
77
- pad = 0 , stride = 1 , dilation = 1 ) where T
78
- args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dw, dy, x, w))
79
- ∇crossconv_filter! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ))
80
- return dw
81
- end
82
-
83
67
function ∇conv_data! (dx:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
84
68
x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
85
69
pad = 0 , stride = 1 , dilation = 1 ) where T
@@ -88,63 +72,30 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
88
72
return dx
89
73
end
90
74
91
- function ∇crossconv_data! (dx:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
92
- x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
93
- pad = 0 , stride = 1 , dilation = 1 ) where T
94
- args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dx, dy, x, w))
95
- ∇crossconv_data! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... , 1 ))
96
- return dx
97
- end
98
-
99
75
conv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
100
- pad = 0 , stride = 1 , dilation = 1 ) where T =
101
- conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
102
-
103
- crossconv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
104
- pad = 0 , stride = 1 , dilation = 1 ) where T =
105
- conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode= 1 )
76
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where T =
77
+ conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
106
78
107
79
∇conv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
108
80
pad = 0 , stride = 1 , dilation = 1 ) where T =
109
81
conv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
110
82
111
- ∇crossconv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
112
- pad = 0 , stride = 1 , dilation = 1 ) where T =
113
- conv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= 1 )
114
-
115
83
∇conv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
116
84
pad = 0 , stride = 1 , dilation = 1 ) where T =
117
85
conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
118
86
119
- ∇crossconv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
120
- pad = 0 , stride = 1 , dilation = 1 ) where T =
121
- conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= 1 )
122
-
123
87
conv! (y:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
124
- pad = 0 , stride = 1 , dilation = 1 ) where T =
125
- conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
126
-
127
- crossconv! (y:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
128
- pad = 0 , stride = 1 , dilation = 1 ) where T =
129
- conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode= 1 )
88
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where T =
89
+ conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
130
90
131
91
∇conv_filter! (dw:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
132
92
pad = 0 , stride = 1 , dilation = 1 ) where T =
133
93
conv3d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
134
94
135
- ∇crossconv_filter! (dw:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
136
- pad = 0 , stride = 1 , dilation = 1 ) where T =
137
- conv3d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= 1 )
138
-
139
95
∇conv_data! (dx:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
140
96
pad = 0 , stride = 1 , dilation = 1 ) where T =
141
97
conv3d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
142
98
143
- ∇crossconv_data! (dx:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
144
- pad = 0 , stride = 1 , dilation = 1 ) where T =
145
- conv3d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= 1 )
146
-
147
-
148
99
# Depthwise Conv
149
100
150
101
function dcdims (x:: NTuple{4,Int} , w:: NTuple{4,Int} , pad, stride)
@@ -162,41 +113,27 @@ function depthwisecrossconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractAr
162
113
end
163
114
164
115
depthwiseconv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
165
- pad = 0 , stride = 1 ) where T =
166
- depthwiseconv2d! (y, x, w, padding = pad, stride = stride)
116
+ pad = 0 , stride = 1 , flipkernel = 0 ) where T =
117
+ depthwiseconv2d! (y, x, w, padding = pad, stride = stride, mode = flipkernel )
167
118
168
119
depthwisecrossconv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
169
120
pad = 0 , stride = 1 ) where T =
170
- depthwiseconv2d ! (y, x, w, padding = pad, stride = stride, mode = 1 )
121
+ depthwiseconv ! (y, x, w, pad = pad, stride = stride, flipkernel = 1 )
171
122
172
123
∇depthwiseconv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
173
124
∇depthwiseconv_data! (zero (x), dy, x, w; pad = pad, stride = stride)
174
125
175
- ∇depthwisecrossconv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
176
- ∇depthwisecrossconv_data! (zero (x), dy, x, w; pad = pad, stride = stride)
177
-
178
126
∇depthwiseconv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
179
127
∇depthwiseconv_filter! (zero (w), dy, x, w; pad = pad, stride = stride)
180
128
181
- ∇depthwisecrossconv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
182
- ∇depthwisecrossconv_filter! (zero (w), dy, x, w; pad = pad, stride = stride)
183
-
184
129
∇depthwiseconv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
185
130
pad = 0 , stride = 1 ) where T =
186
131
depthwiseconv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride)
187
132
188
- ∇depthwisecrossconv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
189
- pad = 0 , stride = 1 ) where T =
190
- depthwiseconv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, mode= 1 )
191
-
192
133
∇depthwiseconv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
193
134
pad = 0 , stride = 1 ) where T =
194
135
depthwiseconv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride)
195
136
196
- ∇depthwisecrossconv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
197
- pad = 0 , stride = 1 ) where T =
198
- depthwiseconv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, mode= 1 )
199
-
200
137
# Pooling
201
138
202
139
function pdims (dims:: Dims{N} , window, padding, stride) where N
0 commit comments