Skip to content

Commit ccc6dad

Browse files
authored
Merge pull request #63 from jbrea/master
allow type of x and w to differ in conv
2 parents a470e99 + b696529 commit ccc6dad

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/conv.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
5454
padtuple(x::Tuple,p::Tuple) = p
5555
padtuple(x::AbstractArray,p) = padtuple(size(x),p)
5656

57-
function conv(x::A, w::A; size=nothing, pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
57+
function conv(x::AbstractArray, w::AbstractArray; size=nothing, pad = 0, stride = 1, dilation = 1)
5858
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
5959
if size === nothing
6060
size = cdims(Base.size(x), dilation_dims(w, dilation), pad_, stride_)
@@ -70,15 +70,15 @@ function crosscor(x::A, w::A; size=nothing, pad = 0, stride = 1, dilation = 1) w
7070
crosscor!(similar(x, size), x, w, pad = pad_, stride = stride_, dilation = dilation)
7171
end
7272

73-
function ∇conv_data(dy::A, w::A; size=nothing, pad = 0, stride = 1, dilation = 1, flipkernel = 0) where A<:AbstractArray
73+
function ∇conv_data(dy::AbstractArray, w::AbstractArray; size=nothing, pad = 0, stride = 1, dilation = 1, flipkernel = 0)
7474
pad_, stride_, dilation_ = padtuple(dy, pad), padtuple(dy, stride), padtuple(dy, dilation)
7575
if size === nothing
7676
size = ctdims(Base.size(dy), Base.size(w), pad_, stride_, dilation_)
7777
end
7878
∇conv_data!(similar(dy, size), dy, w, pad = pad_, stride = stride_, dilation = dilation_, flipkernel=flipkernel)
7979
end
8080

81-
function ∇conv_filter(dy::A, x::A; size = nothing, pad = 0, stride = 1, dilation = 1, flipkernel=0) where A<:AbstractArray
81+
function ∇conv_filter(dy::AbstractArray, x::AbstractArray; size = nothing, pad = 0, stride = 1, dilation = 1, flipkernel=0)
8282
pad_, stride_, dilation_ = padtuple(dy, pad), padtuple(dy, stride), padtuple(dy, dilation)
8383
if size === nothing
8484
size = wdims(Base.size(x), Base.size(dy), pad_, stride_, dilation_)
@@ -144,7 +144,7 @@ function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
144144
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
145145
end
146146

147-
function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
147+
function depthwiseconv(x::AbstractArray, w::AbstractArray; pad = 0, stride = 1)
148148
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
149149
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
150150
end
@@ -162,10 +162,10 @@ depthwisecrosscor!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArra
162162
pad = 0, stride = 1) where T =
163163
depthwiseconv!(y, x, w, pad = pad, stride = stride, flipkernel=1)
164164

165-
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1, flipkernel=0) where A<:AbstractArray =
165+
∇depthwiseconv_data(dy::AbstractArray, x::AbstractArray, w::AbstractArray; pad = 0, stride = 1, flipkernel=0) =
166166
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride, flipkernel=flipkernel)
167167

168-
∇depthwiseconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, flipkernel=0) where A<:AbstractArray =
168+
∇depthwiseconv_filter(dy::AbstractArray, x::AbstractArray, w::AbstractArray; pad = 0, stride = 1, flipkernel=0) =
169169
∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, flipkernel=flipkernel)
170170

171171
∇depthwiseconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};

test/conv.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ using NNlib: conv, crosscor, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool,
99
39 89 139;
1010
49 99 149;
1111
59 109 159.]
12+
13+
@test dropdims(conv(view(x, :, :, :, :), w), dims = (3,4)) == [
14+
29 79 129;
15+
39 89 139;
16+
49 99 149;
17+
59 109 159.]
1218

1319
@test dropdims(crosscor(x, w), dims = (3,4)) == [
1420
51 101 151;

0 commit comments

Comments
 (0)