39
39
∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
40
40
∇conv_data! (zero (x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
41
41
42
+ ∇crossconv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
43
+ ∇crossconv_data! (zero (x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
44
+
42
45
∇conv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
43
46
∇conv_filter! (zero (w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
44
47
48
+ ∇crossconv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
49
+ ∇crossconv_filter! (zero (w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
50
+
45
51
# N-D dispatch
46
52
47
53
function conv! (y:: AbstractArray{T,3} , x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
@@ -66,6 +72,14 @@ function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
66
72
return dw
67
73
end
68
74
75
+ function ∇crossconv_filter! (dw:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
76
+ x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
77
+ pad = 0 , stride = 1 , dilation = 1 ) where T
78
+ args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dw, dy, x, w))
79
+ ∇crossconv_filter! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ))
80
+ return dw
81
+ end
82
+
69
83
function ∇conv_data! (dx:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
70
84
x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
71
85
pad = 0 , stride = 1 , dilation = 1 ) where T
@@ -74,6 +88,14 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
74
88
return dx
75
89
end
76
90
91
+ function ∇crossconv_data! (dx:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
92
+ x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
93
+ pad = 0 , stride = 1 , dilation = 1 ) where T
94
+ args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dx, dy, x, w))
95
+ ∇crossconv_data! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... , 1 ))
96
+ return dx
97
+ end
98
+
77
99
conv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
78
100
pad = 0 , stride = 1 , dilation = 1 ) where T =
79
101
conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
@@ -86,10 +108,18 @@ crossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
86
108
pad = 0 , stride = 1 , dilation = 1 ) where T =
87
109
conv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
88
110
111
+ ∇crossconv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
112
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
113
+ conv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= 1 )
114
+
89
115
∇conv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
90
116
pad = 0 , stride = 1 , dilation = 1 ) where T =
91
117
conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
92
118
119
+ ∇crossconv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
120
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
121
+ conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= 1 )
122
+
93
123
conv! (y:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
94
124
pad = 0 , stride = 1 , dilation = 1 ) where T =
95
125
conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
@@ -102,11 +132,20 @@ crossconv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
102
132
pad = 0 , stride = 1 , dilation = 1 ) where T =
103
133
conv3d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
104
134
135
+ ∇crossconv_filter! (dw:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
136
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
137
+ conv3d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= 1 )
138
+
105
139
∇conv_data! (dx:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
106
140
pad = 0 , stride = 1 , dilation = 1 ) where T =
107
141
conv3d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
108
142
109
- # Depthwise Conv
143
+ ∇crossconv_data! (dx:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
144
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
145
+ conv3d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= 1 )
146
+
147
+
148
+ # Depthwise Conv
110
149
111
150
function dcdims (x:: NTuple{4,Int} , w:: NTuple{4,Int} , pad, stride)
112
151
((x[1 ] + 2 * pad[1 ] - w[1 ])÷ stride[1 ] + 1 ,(x[2 ] + 2 * pad[2 ] - w[2 ])÷ stride[2 ] + 1 ,w[3 ]* w[4 ],x[4 ])
@@ -133,17 +172,31 @@ depthwisecrossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArr
133
172
∇depthwiseconv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
134
173
∇depthwiseconv_data! (zero (x), dy, x, w; pad = pad, stride = stride)
135
174
175
+ ∇depthwisecrossconv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
176
+ ∇depthwisecrossconv_data! (zero (x), dy, x, w; pad = pad, stride = stride)
177
+
136
178
∇depthwiseconv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
137
179
∇depthwiseconv_filter! (zero (w), dy, x, w; pad = pad, stride = stride)
138
180
181
+ ∇depthwisecrossconv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
182
+ ∇depthwisecrossconv_filter! (zero (w), dy, x, w; pad = pad, stride = stride)
183
+
139
184
∇depthwiseconv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
140
185
pad = 0 , stride = 1 ) where T =
141
186
depthwiseconv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride)
142
187
188
+ ∇depthwisecrossconv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
189
+ pad = 0 , stride = 1 ) where T =
190
+ depthwiseconv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, mode= 1 )
191
+
143
192
∇depthwiseconv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
144
193
pad = 0 , stride = 1 ) where T =
145
194
depthwiseconv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride)
146
195
196
+ ∇depthwisecrossconv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
197
+ pad = 0 , stride = 1 ) where T =
198
+ depthwiseconv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, mode= 1 )
199
+
147
200
# Pooling
148
201
149
202
function pdims (dims:: Dims{N} , window, padding, stride) where N
0 commit comments