@@ -24,10 +24,10 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
24
24
padtuple (x:: Tuple ,p:: Tuple ) = p
25
25
padtuple (x:: AbstractArray ,p) = padtuple (size (x),p)
26
26
27
- function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
27
+ function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , mode = 0 ) where A<: AbstractArray
28
28
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
29
29
conv! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)),
30
- x, w, pad = pad_, stride = stride_, dilation = dilation)
30
+ x, w, pad = pad_, stride = stride_, dilation = dilation, mode = mode )
31
31
end
32
32
33
33
∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
@@ -62,8 +62,8 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
62
62
end
63
63
64
64
conv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
65
- pad = 0 , stride = 1 , dilation = 1 ) where T =
66
- conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
65
+ pad = 0 , stride = 1 , dilation = 1 , mode = 0 ) where T =
66
+ conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode = mode )
67
67
68
68
∇conv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
69
69
pad = 0 , stride = 1 , dilation = 1 ) where T =
0 commit comments