Skip to content

Commit 519b5c2

Browse files
authored
Merge pull request #71 from ayush1999/dev_mode
[WIP] mode parameter for Convolution/cross-convolution
2 parents b653dc1 + 56eee8e commit 519b5c2

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

src/conv.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractAr
3030
x, w, pad = pad_, stride = stride_, dilation = dilation)
3131
end
3232

33+
function crossconv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
34+
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
35+
crossconv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
36+
x, w, pad = pad_, stride = stride_, dilation = dilation)
37+
end
38+
3339
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
3440
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
3541

@@ -39,12 +45,17 @@ end
3945
# N-D dispatch
4046

4147
function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
42-
pad = 0, stride = 1, dilation = 1) where T
48+
pad = 0, stride = 1, dilation = 1, flipkernel =0) where T
4349
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
44-
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
50+
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel)
4551
return y
4652
end
4753

54+
function crossconv!(y::AbstractArray, x::AbstractArray, w::AbstractArray;
55+
pad = 0, stride = 1, dilation = 1)
56+
conv!(y, x, w, pad=pad, stride=stride, dilation=dilation, flipkernel=1)
57+
end
58+
4859
function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
4960
x::AbstractArray{T,3}, w::AbstractArray{T,3};
5061
pad = 0, stride = 1, dilation = 1) where T
@@ -62,8 +73,8 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
6273
end
6374

6475
conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
65-
pad = 0, stride = 1, dilation = 1) where T =
66-
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
76+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
77+
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
6778

6879
∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
6980
pad = 0, stride = 1, dilation = 1) where T =
@@ -74,8 +85,8 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
7485
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
7586

7687
conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
77-
pad = 0, stride = 1, dilation = 1) where T =
78-
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
88+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
89+
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
7990

8091
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
8192
pad = 0, stride = 1, dilation = 1) where T =
@@ -85,7 +96,7 @@ conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
8596
pad = 0, stride = 1, dilation = 1) where T =
8697
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
8798

88-
# Depthwise Conv
99+
# Depthwise Conv
89100

90101
function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
91102
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
@@ -96,9 +107,18 @@ function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
96107
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
97108
end
98109

110+
function depthwisecrossconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
111+
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
112+
depthwisecrossconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
113+
end
114+
99115
depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
116+
pad = 0, stride = 1, flipkernel=0) where T =
117+
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode= flipkernel)
118+
119+
depthwisecrossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
100120
pad = 0, stride = 1) where T =
101-
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
121+
depthwiseconv!(y, x, w, pad = pad, stride = stride, flipkernel=1)
102122

103123
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
104124
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)

0 commit comments

Comments
 (0)