@@ -24,16 +24,16 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
24
24
padtuple (x:: Tuple ,p:: Tuple ) = p
25
25
padtuple (x:: AbstractArray ,p) = padtuple (size (x),p)
26
26
27
- function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
27
+ function conv (x:: A , w:: B ; pad = 0 , stride = 1 , dilation = 1 ) where { A<: AbstractArray , B <: AbstractArray }
28
28
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
29
29
conv! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)),
30
30
x, w, pad = pad_, stride = stride_, dilation = dilation)
31
31
end
32
32
33
- ∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
33
+ ∇conv_data (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 , dilation = 1 ) where { A<: AbstractArray , B <: AbstractArray , C <: AbstractArray } =
34
34
∇conv_data! (zero (x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
35
35
36
- ∇conv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
36
+ ∇conv_filter (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 , dilation = 1 ) where { A<: AbstractArray , B <: AbstractArray , C <: AbstractArray } =
37
37
∇conv_filter! (zero (w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
38
38
39
39
# N-D dispatch
@@ -91,7 +91,7 @@ function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
91
91
((x[1 ] + 2 * pad[1 ] - w[1 ])÷ stride[1 ] + 1 ,(x[2 ] + 2 * pad[2 ] - w[2 ])÷ stride[2 ] + 1 ,w[3 ]* w[4 ],x[4 ])
92
92
end
93
93
94
- function depthwiseconv (x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray
94
+ function depthwiseconv (x:: A , w:: B ; pad = 0 , stride = 1 ) where { A<: AbstractArray , B <: AbstractArray }
95
95
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
96
96
depthwiseconv! (similar (x, dcdims (size (x), size (w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
97
97
end
@@ -100,10 +100,10 @@ depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,
100
100
pad = 0 , stride = 1 ) where T =
101
101
depthwiseconv2d! (y, x, w, padding = pad, stride = stride)
102
102
103
- ∇depthwiseconv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
103
+ ∇depthwiseconv_data (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 ) where { A<: AbstractArray , B <: AbstractArray , C <: AbstractArray } =
104
104
∇depthwiseconv_data! (zero (x), dy, x, w; pad = pad, stride = stride)
105
105
106
- ∇depthwiseconv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
106
+ ∇depthwiseconv_filter (dy:: A , x:: B , w:: C ; pad = 0 , stride = 1 ) where { A<: AbstractArray , B <: AbstractArray , C <: AbstractArray } =
107
107
∇depthwiseconv_filter! (zero (w), dy, x, w; pad = pad, stride = stride)
108
108
109
109
∇depthwiseconv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
0 commit comments