Skip to content

Commit 56eee8e

Browse files
committed
Removed duplicate code
1 parent c4d301d commit 56eee8e

File tree

1 file changed

+12
-75
lines changed

1 file changed

+12
-75
lines changed

src/conv.jl

Lines changed: 12 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -39,29 +39,21 @@ end
3939
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
4040
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
4141

42-
∇crossconv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
43-
∇crossconv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
44-
4542
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
4643
∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
4744

48-
∇crossconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
49-
∇crossconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
50-
5145
# N-D dispatch
5246

5347
function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
54-
pad = 0, stride = 1, dilation = 1) where T
48+
pad = 0, stride = 1, dilation = 1, flipkernel =0) where T
5549
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
56-
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
50+
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel)
5751
return y
5852
end
5953

60-
function crossconv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
61-
pad = 0, stride = 1, dilation = 1) where T
62-
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
63-
crossconv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
64-
return y
54+
function crossconv!(y::AbstractArray, x::AbstractArray, w::AbstractArray;
55+
pad = 0, stride = 1, dilation = 1)
56+
conv!(y, x, w, pad=pad, stride=stride, dilation=dilation, flipkernel=1)
6557
end
6658

6759
function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
@@ -72,14 +64,6 @@ function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
7264
return dw
7365
end
7466

75-
function ∇crossconv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
76-
x::AbstractArray{T,3}, w::AbstractArray{T,3};
77-
pad = 0, stride = 1, dilation = 1) where T
78-
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x, w))
79-
∇crossconv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
80-
return dw
81-
end
82-
8367
function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
8468
x::AbstractArray{T,3}, w::AbstractArray{T,3};
8569
pad = 0, stride = 1, dilation = 1) where T
@@ -88,63 +72,30 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
8872
return dx
8973
end
9074

91-
function ∇crossconv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
92-
x::AbstractArray{T,3}, w::AbstractArray{T,3};
93-
pad = 0, stride = 1, dilation = 1) where T
94-
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dx, dy, x, w))
95-
∇crossconv_data!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1))
96-
return dx
97-
end
98-
9975
conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
100-
pad = 0, stride = 1, dilation = 1) where T =
101-
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
102-
103-
crossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
104-
pad = 0, stride = 1, dilation = 1) where T =
105-
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=1)
76+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
77+
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
10678

10779
∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
10880
pad = 0, stride = 1, dilation = 1) where T =
10981
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
11082

111-
∇crossconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
112-
pad = 0, stride = 1, dilation = 1) where T =
113-
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)
114-
11583
∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
11684
pad = 0, stride = 1, dilation = 1) where T =
11785
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
11886

119-
∇crossconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
120-
pad = 0, stride = 1, dilation = 1) where T =
121-
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)
122-
12387
conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
124-
pad = 0, stride = 1, dilation = 1) where T =
125-
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
126-
127-
crossconv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
128-
pad = 0, stride = 1, dilation = 1) where T =
129-
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=1)
88+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
89+
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
13090

13191
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
13292
pad = 0, stride = 1, dilation = 1) where T =
13393
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
13494

135-
∇crossconv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
136-
pad = 0, stride = 1, dilation = 1) where T =
137-
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)
138-
13995
∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
14096
pad = 0, stride = 1, dilation = 1) where T =
14197
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
14298

143-
∇crossconv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
144-
pad = 0, stride = 1, dilation = 1) where T =
145-
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)
146-
147-
14899
# Depthwise Conv
149100

150101
function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
@@ -162,41 +113,27 @@ function depthwisecrossconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractAr
162113
end
163114

164115
depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
165-
pad = 0, stride = 1) where T =
166-
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
116+
pad = 0, stride = 1, flipkernel=0) where T =
117+
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode= flipkernel)
167118

168119
depthwisecrossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
169120
pad = 0, stride = 1) where T =
170-
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode=1)
121+
depthwiseconv!(y, x, w, pad = pad, stride = stride, flipkernel=1)
171122

172123
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
173124
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)
174125

175-
∇depthwisecrossconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
176-
∇depthwisecrossconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)
177-
178126
∇depthwiseconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
179127
∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride)
180128

181-
∇depthwisecrossconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
182-
∇depthwisecrossconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride)
183-
184129
∇depthwiseconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
185130
pad = 0, stride = 1) where T =
186131
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride)
187132

188-
∇depthwisecrossconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
189-
pad = 0, stride = 1) where T =
190-
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, mode=1)
191-
192133
∇depthwiseconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
193134
pad = 0, stride = 1) where T =
194135
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride)
195136

196-
∇depthwisecrossconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
197-
pad = 0, stride = 1) where T =
198-
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, mode=1)
199-
200137
# Pooling
201138

202139
function pdims(dims::Dims{N}, window, padding, stride) where N

0 commit comments

Comments
 (0)