@@ -17,30 +17,74 @@ function cdims(x::NTuple{N}, w::NTuple{N}, pad, stride) where N
17
17
end
18
18
end
19
19
20
+
21
+ # Conv Transpose dims
22
+
23
+ function ctdims (x:: NTuple{N} , w:: NTuple{N} , pad, stride, dilation) where N
24
+ ntuple (Val (N)) do i
25
+ if i < N- 1
26
+ (x[i] - 1 ) * stride[i] + dilation[i] * (w[i] - 1 ) - 2 * pad[i] + 1
27
+ elseif i == N- 1
28
+ w[N- 1 ]
29
+ else # i == N
30
+ x[N]
31
+ end
32
+ end
33
+ end
34
+
35
+
36
+ # Kernel dims
37
+
38
+ function wdims (x:: NTuple{N} , y:: NTuple{N} , pad, stride, dilation) where N
39
+ ntuple (Val (N)) do i
40
+ if i < N- 1
41
+ 1 + div ((1 - y[i]) * stride[i] + x[i] + 2 pad[i] - 1 , dilation[i])
42
+ elseif i == N- 1
43
+ x[i]
44
+ else # i == N
45
+ y[i- 1 ]
46
+ end
47
+ end
48
+ end
49
+
20
50
# Interface
21
51
22
52
head (x) = reverse (Base. tail (reverse (x)))
23
53
padtuple (x:: Tuple ,p:: Integer ) = map (_-> p, head (head (x)))
24
54
padtuple (x:: Tuple ,p:: Tuple ) = p
25
55
padtuple (x:: AbstractArray ,p) = padtuple (size (x),p)
26
56
27
- function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
57
+ function conv (x:: A , w:: A ; size = nothing , pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
28
58
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
29
- conv! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)),
30
- x, w, pad = pad_, stride = stride_, dilation = dilation)
59
+ if size === nothing
60
+ size = cdims (Base. size (x), dilation_dims (w, dilation), pad_, stride_)
61
+ end
62
+ conv! (similar (x, size), x, w, pad = pad_, stride = stride_, dilation = dilation)
31
63
end
32
64
33
- function crosscor (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
65
+ function crosscor (x:: A , w:: A ; size = nothing , pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
34
66
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
35
- crosscor! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)),
36
- x, w, pad = pad_, stride = stride_, dilation = dilation)
67
+ if size === nothing
68
+ size = cdims (Base. size (x), dilation_dims (w, dilation), pad_, stride_)
69
+ end
70
+ crosscor! (similar (x, size), x, w, pad = pad_, stride = stride_, dilation = dilation)
37
71
end
38
72
39
- ∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where A<: AbstractArray =
40
- ∇conv_data! (zero (x), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel= flipkernel)
73
+ function ∇conv_data (dy:: A , w:: A ; size= nothing , pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where A<: AbstractArray
74
+ pad_, stride_, dilation_ = padtuple (dy, pad), padtuple (dy, stride), padtuple (dy, dilation)
75
+ if size === nothing
76
+ size = ctdims (Base. size (dy), Base. size (w), pad_, stride_, dilation_)
77
+ end
78
+ ∇conv_data! (similar (dy, size), dy, w, pad = pad_, stride = stride_, dilation = dilation_, flipkernel= flipkernel)
79
+ end
41
80
42
- ∇conv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where A<: AbstractArray =
43
- ∇conv_filter! (zero (w), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel= flipkernel)
81
+ function ∇conv_filter (dy:: A , x:: A ; size = nothing , pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where A<: AbstractArray
82
+ pad_, stride_, dilation_ = padtuple (dy, pad), padtuple (dy, stride), padtuple (dy, dilation)
83
+ if size === nothing
84
+ size = wdims (Base. size (x), Base. size (dy), pad_, stride_, dilation_)
85
+ end
86
+ ∇conv_filter! (zero (similar (dy, size)), dy, x; pad = pad, stride = stride, dilation = dilation, flipkernel= flipkernel)
87
+ end
44
88
45
89
# N-D dispatch
46
90
@@ -56,18 +100,16 @@ function crosscor!(y::AbstractArray, x::AbstractArray, w::AbstractArray;
56
100
conv! (y, x, w, pad= pad, stride= stride, dilation= dilation, flipkernel= 1 )
57
101
end
58
102
59
- function ∇conv_filter! (dw:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
60
- x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
103
+ function ∇conv_filter! (dw:: AbstractArray{T,3} , dy:: AbstractArray{T,3} , x:: AbstractArray{T,3} ;
61
104
pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where T
62
- args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dw, dy, x, w ))
105
+ args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dw, dy, x))
63
106
∇conv_filter! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... ,1 ), flipkernel= flipkernel)
64
107
return dw
65
108
end
66
109
67
- function ∇conv_data! (dx:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
68
- x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
69
- pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T
70
- args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dx, dy, x, w))
110
+ function ∇conv_data! (dx:: AbstractArray{T,3} , dy:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
111
+ pad = 0 , stride = 1 , dilation = 1 , flipkernel = 0 ) where T
112
+ args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dx, dy, w))
71
113
∇conv_data! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation... , 1 ), flipkernel = flipkernel)
72
114
return dx
73
115
end
@@ -76,25 +118,25 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
76
118
pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where T =
77
119
conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
78
120
79
- ∇conv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w :: AbstractArray{T,4} ;
121
+ ∇conv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} ;
80
122
pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where T =
81
- conv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
123
+ conv2d_grad_w! (dw, x, dy, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
82
124
83
- ∇conv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x :: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
125
+ ∇conv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
84
126
pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where T =
85
- conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
127
+ conv2d_grad_x! (dx, w, dy, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
86
128
87
129
conv! (y:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
88
130
pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where T =
89
131
conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
90
132
91
- ∇conv_filter! (dw:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w :: AbstractArray{T,5} ;
133
+ ∇conv_filter! (dw:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} ;
92
134
pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where T =
93
- conv3d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
135
+ conv3d_grad_w! (dw, x, dy, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
94
136
95
- ∇conv_data! (dx:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x :: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
137
+ ∇conv_data! (dx:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
96
138
pad = 0 , stride = 1 , dilation = 1 , flipkernel= 0 ) where T =
97
- conv3d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
139
+ conv3d_grad_x! (dx, w, dy, padding = pad, stride = stride, dilation = dilation, mode= flipkernel)
98
140
99
141
# Depthwise Conv
100
142
@@ -216,3 +258,9 @@ meanpool_cpu!(y::AbstractArray{<:Real,5}, x::AbstractArray{<:Real,5}, k::Dims{3}
216
258
k:: Dims{3} ; pad = (0 ,0 ), stride = k) =
217
259
meanpool3d_grad! (dx, dy, y, x,
218
260
window = k, padding = pad, stride = stride)
261
+
262
+ # Deprecated
263
+
264
+ # 0.4.2
265
+ @deprecate ∇conv_data (dy:: A , x:: A , w:: A ; kw... ) where A<: AbstractArray ∇conv_data (dy, w; size= size (x), kw... )
266
+ @deprecate ∇conv_filter (dy:: A , x:: A , w:: A ; kw... ) where A<: AbstractArray ∇conv_filter (dy, x; size= size (w), kw... )
0 commit comments