Skip to content

Commit e7fd08d

Browse files
committed
allow type of x and w to differ in conv
1 parent d6aaa81 commit e7fd08d

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/conv.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,16 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
2424
padtuple(x::Tuple,p::Tuple) = p
2525
padtuple(x::AbstractArray,p) = padtuple(size(x),p)
2626

27-
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
27+
function conv(x::A, w::B; pad = 0, stride = 1, dilation = 1) where {A<:AbstractArray, B<:AbstractArray}
2828
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
2929
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
3030
x, w, pad = pad_, stride = stride_, dilation = dilation)
3131
end
3232

33-
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
33+
∇conv_data(dy::A, x::B, w::C; pad = 0, stride = 1, dilation = 1) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
3434
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
3535

36-
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
36+
∇conv_filter(dy::A, x::B, w::C; pad = 0, stride = 1, dilation = 1) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
3737
∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
3838

3939
# N-D dispatch
@@ -91,7 +91,7 @@ function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
9191
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
9292
end
9393

94-
function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
94+
function depthwiseconv(x::A, w::B; pad = 0, stride = 1) where {A<:AbstractArray, B<:AbstractArray}
9595
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
9696
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
9797
end
@@ -100,10 +100,10 @@ depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,
100100
pad = 0, stride = 1) where T =
101101
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
102102

103-
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
103+
∇depthwiseconv_data(dy::A, x::B, w::C; pad = 0, stride = 1) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
104104
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)
105105

106-
∇depthwiseconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
106+
∇depthwiseconv_filter(dy::A, x::B, w::C; pad = 0, stride = 1) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
107107
∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride)
108108

109109
∇depthwiseconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};

test/conv.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwisec
99
39 89 139;
1010
49 99 149;
1111
59 109 159.]
12+
13+
@test dropdims(conv(view(x, :, :, :, :), w), dims = (3,4)) == [
14+
29 79 129;
15+
39 89 139;
16+
49 99 149;
17+
59 109 159.]
1218

1319
@test dropdims(conv(x, w; stride=2), dims = (3,4)) == [
1420
29 129;

0 commit comments

Comments
 (0)