@@ -30,62 +30,73 @@ function conv(x::A, w::B; pad = 0, stride = 1, dilation = 1) where {A<:AbstractA
30
30
x, w, pad = pad_, stride = stride_, dilation = dilation)
31
31
end
32
32
33
- ∇conv_data (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 , dilation = 1 ) where {A<: AbstractArray , B<: AbstractArray , C<: AbstractArray } =
34
- ∇conv_data! (zero (x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
33
+ function crosscor (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
34
+ pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
35
+ crosscor! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)),
36
+ x, w, pad = pad_, stride = stride_, dilation = dilation)
37
+ end
35
38
36
- ∇conv_filter (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 , dilation = 1 ) where {A<: AbstractArray , B<: AbstractArray , C<: AbstractArray } =
37
- ∇conv_filter! (zero (w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
39
+ ∇conv_data (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where {A<: AbstractArray , B<: AbstractArray , C<: AbstractArray } =
40
+ ∇conv_data! (zero (x), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel= flipkernel)
41
+
42
+ ∇conv_filter (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where {A<: AbstractArray , B<: AbstractArray , C<: AbstractArray } =
43
+ ∇conv_filter! (zero (w), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel= flipkernel)
38
44
39
45
# N-D dispatch
40
46
41
47
function conv! (y:: AbstractArray{T,3} , x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
42
- pad = 0 , stride = 1 , dilation = 1 ) where T
48
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T
43
49
args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (y, x, w))
44
- conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ))
50
+ conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ), flipkernel = flipkernel )
45
51
return y
46
52
end
47
53
54
+ function crosscor! (y:: AbstractArray , x:: AbstractArray , w:: AbstractArray ;
55
+ pad = 0 , stride = 1 , dilation = 1 )
56
+ conv! (y, x, w, pad= pad, stride= stride, dilation= dilation, flipkernel= 1 )
57
+ end
58
+
48
59
function ∇conv_filter! (dw:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
49
60
x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
50
- pad = 0 , stride = 1 , dilation = 1 ) where T
61
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T
51
62
args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dw, dy, x, w))
52
- ∇conv_filter! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ))
63
+ ∇conv_filter! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ), flipkernel = flipkernel )
53
64
return dw
54
65
end
55
66
56
67
function ∇conv_data! (dx:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
57
68
x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
58
- pad = 0 , stride = 1 , dilation = 1 ) where T
69
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T
59
70
args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dx, dy, x, w))
60
- ∇conv_data! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... , 1 ))
71
+ ∇conv_data! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... , 1 ), flipkernel = flipkernel )
61
72
return dx
62
73
end
63
74
64
75
conv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
65
- pad = 0 , stride = 1 , dilation = 1 ) where T =
66
- conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
76
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T =
77
+ conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode = flipkernel )
67
78
68
79
∇conv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
69
- pad = 0 , stride = 1 , dilation = 1 ) where T =
70
- conv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
80
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T =
81
+ conv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode = flipkernel )
71
82
72
83
∇conv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
73
- pad = 0 , stride = 1 , dilation = 1 ) where T =
74
- conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
84
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T =
85
+ conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode = flipkernel )
75
86
76
87
conv! (y:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
77
- pad = 0 , stride = 1 , dilation = 1 ) where T =
78
- conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
88
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T =
89
+ conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode = flipkernel )
79
90
80
91
∇conv_filter! (dw:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
81
- pad = 0 , stride = 1 , dilation = 1 ) where T =
82
- conv3d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
92
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T =
93
+ conv3d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode = flipkernel )
83
94
84
95
∇conv_data! (dx:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
85
- pad = 0 , stride = 1 , dilation = 1 ) where T =
86
- conv3d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
96
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T =
97
+ conv3d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode = flipkernel )
87
98
88
- # Depthwise Conv
99
+ # Depthwise Conv
89
100
90
101
function dcdims (x:: NTuple{4,Int} , w:: NTuple{4,Int} , pad, stride)
91
102
((x[1 ] + 2 * pad[1 ] - w[1 ])÷ stride[1 ] + 1 ,(x[2 ] + 2 * pad[2 ] - w[2 ])÷ stride[2 ] + 1 ,w[3 ]* w[4 ],x[4 ])
@@ -96,23 +107,32 @@ function depthwiseconv(x::A, w::B; pad = 0, stride = 1) where {A<:AbstractArray,
96
107
depthwiseconv! (similar (x, dcdims (size (x), size (w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
97
108
end
98
109
110
+ function depthwisecrosscor (x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray
111
+ pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
112
+ depthwisecrosscor! (similar (x, dcdims (size (x), size (w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
113
+ end
114
+
99
115
depthwiseconv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
116
+ pad = 0 , stride = 1 , flipkernel= 0 ) where T =
117
+ depthwiseconv2d! (y, x, w, padding = pad, stride = stride, mode= flipkernel)
118
+
119
+ depthwisecrosscor! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
100
120
pad = 0 , stride = 1 ) where T =
101
- depthwiseconv2d ! (y, x, w, padding = pad, stride = stride)
121
+ depthwiseconv ! (y, x, w, pad = pad, stride = stride, flipkernel = 1 )
102
122
103
- ∇depthwiseconv_data (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 ) where {A<: AbstractArray , B<: AbstractArray , C<: AbstractArray } =
104
- ∇depthwiseconv_data! (zero (x), dy, x, w; pad = pad, stride = stride)
123
+ ∇depthwiseconv_data (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 , flipkernel = 0 ) where {A<: AbstractArray , B<: AbstractArray , C<: AbstractArray } =
124
+ ∇depthwiseconv_data! (zero (x), dy, x, w; pad = pad, stride = stride, flipkernel = flipkernel )
105
125
106
- ∇depthwiseconv_filter (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 ) where {A<: AbstractArray , B<: AbstractArray , C<: AbstractArray } =
107
- ∇depthwiseconv_filter! (zero (w), dy, x, w; pad = pad, stride = stride)
126
+ ∇depthwiseconv_filter (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 , flipkernel = 0 ) where {A<: AbstractArray , B<: AbstractArray , C<: AbstractArray } =
127
+ ∇depthwiseconv_filter! (zero (w), dy, x, w; pad = pad, stride = stride, flipkernel = flipkernel )
108
128
109
129
∇depthwiseconv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
110
- pad = 0 , stride = 1 ) where T =
111
- depthwiseconv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride)
130
+ pad = 0 , stride = 1 , flipkernel = 0 ) where T =
131
+ depthwiseconv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, mode = flipkernel )
112
132
113
133
∇depthwiseconv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
114
- pad = 0 , stride = 1 ) where T =
115
- depthwiseconv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride)
134
+ pad = 0 , stride = 1 , flipkernel = 0 ) where T =
135
+ depthwiseconv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, mode = flipkernel )
116
136
117
137
# Pooling
118
138
0 commit comments