Skip to content

Commit 25dba3d

Browse files
committed
Changes for depthwiseconv
1 parent b94257f commit 25dba3d

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/conv.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ end
3939
# N-D dispatch
4040

4141
function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
42-
pad = 0, stride = 1, dilation = 1) where T
42+
pad = 0, stride = 1, dilation = 1, mode=0) where T
4343
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
44-
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
44+
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), mode=mode)
4545
return y
4646
end
4747

@@ -74,8 +74,8 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
7474
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
7575

7676
conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
77-
pad = 0, stride = 1, dilation = 1) where T =
78-
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
77+
pad = 0, stride = 1, dilation = 1, mode=0) where T =
78+
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=mode)
7979

8080
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
8181
pad = 0, stride = 1, dilation = 1) where T =
@@ -91,14 +91,14 @@ function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
9191
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
9292
end
9393

94-
function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
94+
function depthwiseconv(x::A, w::A; pad = 0, stride = 1, mode=0) where A<:AbstractArray
9595
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
96-
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
96+
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_, mode=mode)
9797
end
9898

9999
depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
100-
pad = 0, stride = 1) where T =
101-
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
100+
pad = 0, stride = 1, mode=0) where T =
101+
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode=mode)
102102

103103
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
104104
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)

0 commit comments

Comments
 (0)