@@ -24,66 +24,66 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
24
24
padtuple (x:: Tuple ,p:: Tuple ) = p
25
25
padtuple (x:: AbstractArray ,p) = padtuple (size (x),p)
26
26
27
- function conv (x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray
27
+ function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray
28
28
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
29
- conv! (similar (x, cdims (size (x), size (w ), pad_, stride_)),
30
- x, w, pad = pad_, stride = stride_)
29
+ conv! (similar (x, cdims (size (x), dilation_dims (w, dilation ), pad_, stride_)),
30
+ x, w, pad = pad_, stride = stride_, dilation = dilation )
31
31
end
32
32
33
- ∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
34
- ∇conv_data! (zeros (x), dy, x, w; pad = pad, stride = stride)
33
+ ∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
34
+ ∇conv_data! (zeros (x), dy, x, w; pad = pad, stride = stride, dilation = dilation )
35
35
36
- ∇conv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 ) where A<: AbstractArray =
37
- ∇conv_filter! (zeros (w), dy, x, w; pad = pad, stride = stride)
36
+ ∇conv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 ) where A<: AbstractArray =
37
+ ∇conv_filter! (zeros (w), dy, x, w; pad = pad, stride = stride, dilation = dilation )
38
38
39
39
# N-D dispatch
40
40
41
41
function conv! (y:: AbstractArray{T,3} , x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
42
- pad = 0 , stride = 1 ) where T
42
+ pad = 0 , stride = 1 , dilation = 1 ) where T
43
43
args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (y, x, w))
44
- conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ))
44
+ conv! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation ... , 1 ) )
45
45
return y
46
46
end
47
47
48
48
function ∇conv_filter! (dw:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
49
49
x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
50
- pad = 0 , stride = 1 ) where T
50
+ pad = 0 , stride = 1 , dilation = 1 ) where T
51
51
args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dw, dy, x, w))
52
- ∇conv_filter! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ))
52
+ ∇conv_filter! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation ... , 1 ) )
53
53
return dw
54
54
end
55
55
56
56
function ∇conv_data! (dx:: AbstractArray{T,3} , dy:: AbstractArray{T,3} ,
57
57
x:: AbstractArray{T,3} , w:: AbstractArray{T,3} ;
58
- pad = 0 , stride = 1 ) where T
58
+ pad = 0 , stride = 1 , dilation = 1 ) where T
59
59
args = map (x -> reshape (x, size (x,1 ),1 ,size (x,2 ),size (x,3 )), (dx, dy, x, w))
60
- ∇conv_data! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ))
60
+ ∇conv_data! (args... , pad = (pad... ,0 ), stride = (stride... ,1 ), dilation = (dilation ... , 1 ) )
61
61
return dx
62
62
end
63
63
64
64
conv! (y:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
65
- pad = 0 , stride = 1 ) where T =
66
- conv2d! (y, x, w, padding = pad, stride = stride)
65
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
66
+ conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation )
67
67
68
68
∇conv_filter! (dw:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
69
- pad = 0 , stride = 1 ) where T =
70
- conv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride)
69
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
70
+ conv2d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation )
71
71
72
72
∇conv_data! (dx:: AbstractArray{T,4} , dy:: AbstractArray{T,4} , x:: AbstractArray{T,4} , w:: AbstractArray{T,4} ;
73
- pad = 0 , stride = 1 ) where T =
74
- conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride)
73
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
74
+ conv2d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation )
75
75
76
76
conv! (y:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
77
- pad = 0 , stride = 1 ) where T =
78
- conv3d! (y, x, w, padding = pad, stride = stride)
77
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
78
+ conv3d! (y, x, w, padding = pad, stride = stride, dilation = dilation )
79
79
80
80
∇conv_filter! (dw:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
81
- pad = 0 , stride = 1 ) where T =
82
- conv3d_grad_w! (dw, x, w, dy, padding = pad, stride = stride)
81
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
82
+ conv3d_grad_w! (dw, x, w, dy, padding = pad, stride = stride, dilation = dilation )
83
83
84
84
∇conv_data! (dx:: AbstractArray{T,5} , dy:: AbstractArray{T,5} , x:: AbstractArray{T,5} , w:: AbstractArray{T,5} ;
85
- pad = 0 , stride = 1 ) where T =
86
- conv3d_grad_x! (dx, x, w, dy, padding = pad, stride = stride)
85
+ pad = 0 , stride = 1 , dilation = 1 ) where T =
86
+ conv3d_grad_x! (dx, x, w, dy, padding = pad, stride = stride, dilation = dilation )
87
87
88
88
# Pooling
89
89
0 commit comments