Skip to content

Commit 513ff01

Browse files
authored
Merge pull request #40 from tejank10/dilation
Dilation support for Conv
2 parents 2cdf41b + e0c216e commit 513ff01

File tree

3 files changed

+120
-91
lines changed

3 files changed

+120
-91
lines changed

src/conv.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,66 +24,66 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
2424
padtuple(x::Tuple,p::Tuple) = p
2525
padtuple(x::AbstractArray,p) = padtuple(size(x),p)
2626

27-
function conv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
27+
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
2828
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
29-
conv!(similar(x, cdims(size(x), size(w), pad_, stride_)),
30-
x, w, pad = pad_, stride = stride_)
29+
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
30+
x, w, pad = pad_, stride = stride_, dilation = dilation)
3131
end
3232

33-
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
34-
∇conv_data!(zeros(x), dy, x, w; pad = pad, stride = stride)
33+
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
34+
∇conv_data!(zeros(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
3535

36-
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
37-
∇conv_filter!(zeros(w), dy, x, w; pad = pad, stride = stride)
36+
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
37+
∇conv_filter!(zeros(w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
3838

3939
# N-D dispatch
4040

4141
function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
42-
pad = 0, stride = 1) where T
42+
pad = 0, stride = 1, dilation = 1) where T
4343
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
44-
conv!(args..., pad = (pad...,0), stride = (stride...,1))
44+
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
4545
return y
4646
end
4747

4848
function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
4949
x::AbstractArray{T,3}, w::AbstractArray{T,3};
50-
pad = 0, stride = 1) where T
50+
pad = 0, stride = 1, dilation = 1) where T
5151
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x, w))
52-
∇conv_filter!(args..., pad = (pad...,0), stride = (stride...,1))
52+
∇conv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
5353
return dw
5454
end
5555

5656
function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
5757
x::AbstractArray{T,3}, w::AbstractArray{T,3};
58-
pad = 0, stride = 1) where T
58+
pad = 0, stride = 1, dilation = 1) where T
5959
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dx, dy, x, w))
60-
∇conv_data!(args..., pad = (pad...,0), stride = (stride...,1))
60+
∇conv_data!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1))
6161
return dx
6262
end
6363

6464
conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
65-
pad = 0, stride = 1) where T =
66-
conv2d!(y, x, w, padding = pad, stride = stride)
65+
pad = 0, stride = 1, dilation = 1) where T =
66+
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
6767

6868
∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
69-
pad = 0, stride = 1) where T =
70-
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride)
69+
pad = 0, stride = 1, dilation = 1) where T =
70+
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
7171

7272
∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
73-
pad = 0, stride = 1) where T =
74-
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride)
73+
pad = 0, stride = 1, dilation = 1) where T =
74+
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
7575

7676
conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
77-
pad = 0, stride = 1) where T =
78-
conv3d!(y, x, w, padding = pad, stride = stride)
77+
pad = 0, stride = 1, dilation = 1) where T =
78+
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
7979

8080
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
81-
pad = 0, stride = 1) where T =
82-
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride)
81+
pad = 0, stride = 1, dilation = 1) where T =
82+
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
8383

8484
∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
85-
pad = 0, stride = 1) where T =
86-
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride)
85+
pad = 0, stride = 1, dilation = 1) where T =
86+
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
8787

8888
# Pooling
8989

0 commit comments

Comments
 (0)