Skip to content

Commit 4584bef

Browse files
committed
Adapted to previous changes
1 parent 25dba3d commit 4584bef

File tree

1 file changed

+42
-12
lines changed

1 file changed

+42
-12
lines changed

src/conv.jl

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,16 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
2424
padtuple(x::Tuple,p::Tuple) = p
2525
padtuple(x::AbstractArray,p) = padtuple(size(x),p)
2626

27-
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, mode = 0) where A<:AbstractArray
27+
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
2828
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
2929
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
30-
x, w, pad = pad_, stride = stride_, dilation = dilation, mode = mode)
30+
x, w, pad = pad_, stride = stride_, dilation = dilation)
31+
end
32+
33+
function crossconv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
34+
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
35+
crossconv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
36+
x, w, pad = pad_, stride = stride_, dilation = dilation)
3137
end
3238

3339
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
@@ -39,12 +45,19 @@ end
3945
# N-D dispatch
4046

4147
function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
42-
pad = 0, stride = 1, dilation = 1, mode=0) where T
48+
pad = 0, stride = 1, dilation = 1) where T
4349
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
44-
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), mode=mode)
50+
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
4551
return y
4652
end
4753

54+
function crossconv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
55+
pad = 0, stride = 1, dilation = 1) where T
56+
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
57+
crossconv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
58+
return y
59+
end
60+
4861
function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
4962
x::AbstractArray{T,3}, w::AbstractArray{T,3};
5063
pad = 0, stride = 1, dilation = 1) where T
@@ -62,8 +75,12 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
6275
end
6376

6477
conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
65-
pad = 0, stride = 1, dilation = 1, mode=0) where T =
66-
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=mode)
78+
pad = 0, stride = 1, dilation = 1) where T =
79+
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
80+
81+
crossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
82+
pad = 0, stride = 1, dilation = 1) where T =
83+
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=1)
6784

6885
∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
6986
pad = 0, stride = 1, dilation = 1) where T =
@@ -74,8 +91,12 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
7491
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
7592

7693
conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
77-
pad = 0, stride = 1, dilation = 1, mode=0) where T =
78-
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=mode)
94+
pad = 0, stride = 1, dilation = 1) where T =
95+
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
96+
97+
crossconv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
98+
pad = 0, stride = 1, dilation = 1) where T =
99+
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=1)
79100

80101
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
81102
pad = 0, stride = 1, dilation = 1) where T =
@@ -91,14 +112,23 @@ function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
91112
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
92113
end
93114

94-
function depthwiseconv(x::A, w::A; pad = 0, stride = 1, mode=0) where A<:AbstractArray
115+
function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
116+
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
117+
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
118+
end
119+
120+
function depthwisecrossconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
95121
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
96-
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_, mode=mode)
122+
depthwisecrossconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
97123
end
98124

99125
depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
100-
pad = 0, stride = 1, mode=0) where T =
101-
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode=mode)
126+
pad = 0, stride = 1) where T =
127+
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
128+
129+
depthwisecrossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
130+
pad = 0, stride = 1) where T =
131+
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode=1)
102132

103133
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
104134
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)

0 commit comments

Comments
 (0)