@@ -26,20 +26,20 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
 # cdims = ConvDims(x, w; stride=2, dilation=(3,2))
 # dx = ∇conv_data(conv(x, w, cdims), w, cdims)

-# The computational flow, starting from the user facing functions,
-# goes through the following steps:
+# The computational flow, starting from the user facing functions,
+# goes through the following steps:
 #
-#   STEP 1:
+#   STEP 1:
 #     use ConvDims objects (only for `conv` and `depthwiseconv`)
-#   STEP 2:
+#   STEP 2:
 #     define autoallocating version (frontend and implementations)
-#   STEP 3:
+#   STEP 3:
 #     reshape to 3d convolutions (frontend and implementations)
-#   STEP 4:
+#   STEP 4:
 #     choose implementation

 # TODO: should we also add
-#   STEP X:
+#   STEP X:
 #     use homogeneous datatypes
 #     to handle heterogeneous inputs now handled by conv_direct?
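
# A hedged sketch of how one call traverses the four steps above. `NNlib`,
# the sizes, and the backend named in the comments are assumptions for
# illustration, not part of this diff.
using NNlib

x = randn(Float32, 28, 28, 3, 8)    # width × height × channels_in × batch
w = randn(Float32, 3, 3, 3, 16)     # kW × kH × channels_in × channels_out

# STEP 1: the keyword frontend packs the geometry into a ConvDims object.
cdims = DenseConvDims(size(x), size(w); stride = (1, 1))

# STEP 2: the autoallocating frontend creates `y` and forwards to `conv!`,
# which (STEP 3) reshapes everything to a 3d convolution and (STEP 4)
# dispatches to a backend such as conv_im2col!.
y = conv(x, w, cdims)
@assert size(y) == (26, 26, 16, 8)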
@@ -48,22 +48,23 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
 """
     conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1)

-Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
-in 1d/2d/3d convolutions respectively.
+Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
+in 1d/2d/3d convolutions respectively.
 """
-function conv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false, groups=1) where {T, N}
-    stride = expand(Val(N-2), stride)
-    pad = expand(Val(N-2), pad)
-    dilation = expand(Val(N-2), dilation)
-    cdims = DenseConvDims(x, w; stride=stride, padding=pad, dilation=dilation, flipkernel=flipped, groups=groups)
+function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}
+    stride = expand(Val(N - 2), stride)
+    padding = expand(Val(N - 2), pad)
+    dilation = expand(Val(N - 2), dilation)
+    cdims = DenseConvDims(
+        size(x), size(w); stride, padding, dilation, flipkernel=flipped, groups)
     return conv(x, w, cdims)
 end

 """
     depthwiseconv(x, w; stride=1, pad=0, dilation=1, flipped=false)

-Depthwise convolution operation with filter `w` on input `x`. `x` and `w`
-are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
+Depthwise convolution operation with filter `w` on input `x`. `x` and `w`
+are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
 """
 function depthwiseconv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N}
     stride = expand(Val(N-2), stride)
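
# Hedged usage sketch for the two keyword frontends above: `expand` turns each
# scalar keyword into a per-dimension tuple, so `stride = 2` on a 4d input
# behaves like `stride = (2, 2)`. The sizes are illustrative assumptions.
using NNlib

x = rand(Float32, 32, 32, 4, 1)
w = rand(Float32, 5, 5, 4, 8)
y = conv(x, w; stride = 2, pad = 2)    # ⌊(32 + 2·2 − 5)/2⌋ + 1 = 16
@assert size(y) == (16, 16, 8, 1)

# Depthwise: the filter's trailing dims are (channel multiplier, channels_in),
# so this produces 2 × 4 = 8 output channels.
wd = rand(Float32, 3, 3, 2, 4)
yd = depthwiseconv(x, wd; pad = 1)
@assert size(yd) == (32, 32, 8, 1)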
@@ -98,9 +99,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
         function $(Symbol("$(name)$(backend)"))(
                         dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
                         cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
-            dx = similar(dy, input_size(cdims)..., channels_in(cdims),
-                         size(dy, N))
-
+            dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, N))
             return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...)
         end
     end
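
# Hedged sketch of what the generated autoallocating wrapper does: allocate
# `dx` with the input's spatial size and channel count, then delegate to the
# in-place `!` method. `NNlib` and the sizes are illustrative assumptions.
using NNlib

x = rand(Float32, 10, 3, 2)         # 1d conv: length × channels × batch
w = rand(Float32, 3, 3, 5)
cdims = DenseConvDims(size(x), size(w))
dy = conv(x, w, cdims)              # stand-in for an upstream gradient
dx = ∇conv_data(dy, w, cdims)       # allocates dx, then calls ∇conv_data!
@assert size(dx) == size(x)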
@@ -114,7 +113,6 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
                 cdims::ConvDims; kwargs...) where {xT, yT, N}
             dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims),
                          channels_out(cdims))
-
             return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
         end
     end
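
# Hedged sketch of the `dw` shape allocated above: with grouped convolution the
# filter's third dimension is `channels_in ÷ groupcount`, and the gradient
# buffer mirrors it. `NNlib` and the sizes are illustrative assumptions.
using NNlib

x = rand(Float32, 8, 8, 6, 1)                        # 6 input channels
w = rand(Float32, 3, 3, 2, 9)                        # groups = 3 ⇒ 6 ÷ 3 = 2
cdims = DenseConvDims(size(x), size(w); groups = 3)
dy = conv(x, w, cdims)
dw = ∇conv_filter(x, dy, cdims)
@assert size(dw) == (3, 3, 2, 9)                     # kernel..., C_in ÷ G, C_out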
@@ -197,15 +195,15 @@ for (front_name, backend) in (
                 G = 1,
                 C_in = channels_in(cdims) ÷ groupcount(cdims),
                 C_out = channels_out(cdims) ÷ groupcount(cdims))
-
+
             Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
                 x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
                 w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
                 y = @view out[ntuple(i -> i == 4 ? wc : Colon(), 5)...]
                 Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
             end

-            return out
+            return out
         end
     end
 end
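
# Hedged, standalone sketch of the partition-and-spawn pattern above: split the
# channel axes into one contiguous range per group, take disjoint views, and
# run one task per group. The buffer and group count are illustrative
# assumptions; `fill!` stands in for the real `conv_im2col!` call.
x_cs = Iterators.partition(1:6, 6 ÷ 3)    # channels_in ÷ groupcount
w_cs = Iterators.partition(1:9, 9 ÷ 3)    # channels_out ÷ groupcount

buf = zeros(Float32, 4, 9)                # stand-in output with 9 "channels"
Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
    yv = @view buf[:, wc]                 # each task writes a disjoint slice
    Threads.@spawn fill!(yv, 1f0)
end
@assert all(buf .== 1)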
@@ -232,12 +230,11 @@ function ∇conv_data!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
         Threads.@spawn ∇conv_data_im2col!(dxv, dyv, wv, cdims2; kwargs...)
     end

-    return out
+    return out
 end

 function ∇conv_filter!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
                        in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}
-
     dw_cs = Iterators.partition(1:size(out, 5),
                                 channels_out(cdims) ÷ groupcount(cdims))
     dy_cs = Iterators.partition(1:size(in2, 4),
@@ -256,7 +253,7 @@ function ∇conv_filter!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
         Threads.@spawn ∇conv_filter_im2col!(dw, x, dy, cdims2; kwargs...)
     end

-    return out
+    return out
 end
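
# Hedged arithmetic sketch of the group bookkeeping in `∇conv_filter!` above:
# the filter gradient's output-channel axis is split into one contiguous range
# per group, matching the split of `dy`'s channel axis. `C_out` and `G` are
# illustrative assumptions.
C_out, G = 8, 2
dw_cs = Iterators.partition(1:C_out, C_out ÷ G)
@assert collect(dw_cs) == [1:4, 5:8]      # one block of C_out ÷ G channels per group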