Skip to content

Commit c4d301d

Browse files
committed
crossconv for depthwise conv
1 parent 4584bef commit c4d301d

File tree

1 file changed

+54
-1
lines changed

1 file changed

+54
-1
lines changed

src/conv.jl

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,15 @@ end
3939
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
4040
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
4141

42+
∇crossconv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
43+
∇crossconv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
44+
4245
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
4346
∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
4447

48+
∇crossconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
49+
∇crossconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
50+
4551
# N-D dispatch
4652

4753
function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
@@ -66,6 +72,14 @@ function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
6672
return dw
6773
end
6874

75+
function ∇crossconv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
76+
x::AbstractArray{T,3}, w::AbstractArray{T,3};
77+
pad = 0, stride = 1, dilation = 1) where T
78+
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x, w))
79+
∇crossconv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
80+
return dw
81+
end
82+
6983
function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
7084
x::AbstractArray{T,3}, w::AbstractArray{T,3};
7185
pad = 0, stride = 1, dilation = 1) where T
@@ -74,6 +88,14 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
7488
return dx
7589
end
7690

91+
function ∇crossconv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
92+
x::AbstractArray{T,3}, w::AbstractArray{T,3};
93+
pad = 0, stride = 1, dilation = 1) where T
94+
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dx, dy, x, w))
95+
∇crossconv_data!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1))
96+
return dx
97+
end
98+
7799
conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
78100
pad = 0, stride = 1, dilation = 1) where T =
79101
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
@@ -86,10 +108,18 @@ crossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
86108
pad = 0, stride = 1, dilation = 1) where T =
87109
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
88110

111+
∇crossconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
112+
pad = 0, stride = 1, dilation = 1) where T =
113+
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)
114+
89115
∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
90116
pad = 0, stride = 1, dilation = 1) where T =
91117
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
92118

119+
∇crossconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
120+
pad = 0, stride = 1, dilation = 1) where T =
121+
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)
122+
93123
conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
94124
pad = 0, stride = 1, dilation = 1) where T =
95125
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
@@ -102,11 +132,20 @@ crossconv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
102132
pad = 0, stride = 1, dilation = 1) where T =
103133
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
104134

135+
∇crossconv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
136+
pad = 0, stride = 1, dilation = 1) where T =
137+
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)
138+
105139
∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
106140
pad = 0, stride = 1, dilation = 1) where T =
107141
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
108142

109-
# Depthwise Conv
143+
∇crossconv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
144+
pad = 0, stride = 1, dilation = 1) where T =
145+
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)
146+
147+
148+
# Depthwise Conv
110149

111150
function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
112151
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
@@ -133,17 +172,31 @@ depthwisecrossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArr
133172
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
134173
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)
135174

175+
∇depthwisecrossconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
176+
∇depthwisecrossconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)
177+
136178
∇depthwiseconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
137179
∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride)
138180

181+
∇depthwisecrossconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
182+
∇depthwisecrossconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride)
183+
139184
∇depthwiseconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
140185
pad = 0, stride = 1) where T =
141186
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride)
142187

188+
∇depthwisecrossconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
189+
pad = 0, stride = 1) where T =
190+
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, mode=1)
191+
143192
∇depthwiseconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
144193
pad = 0, stride = 1) where T =
145194
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride)
146195

196+
∇depthwisecrossconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
197+
pad = 0, stride = 1) where T =
198+
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, mode=1)
199+
147200
# Pooling
148201

149202
function pdims(dims::Dims{N}, window, padding, stride) where N

0 commit comments

Comments
 (0)