@@ -10,9 +10,40 @@ using CUDA.CUDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
10
10
11
11
const CUDNNFloat = Union{Float16,Float32,Float64}
12
12
13
- function cudnnConvolutionDescriptor (cdims:: DenseConvDims , x:: DenseCuArray{T} ) where T
13
+ function cudnnConvolutionDescriptorAndPaddedInput (cdims:: DenseConvDims , x:: DenseCuArray{T} ) where T
14
+ # The main purpose of this function is to catch asymmetric padding which cudnn does not support
15
+ # If we find asymmetric padding we'll make a copy of x which is manually padded so that we can
16
+ # call cudnn with symmetric padding.
17
+ pad = collect (NNlib. padding (cdims)) # work with an array to make things more type stable
18
+ all (pad[1 : 2 : end ] .== pad[2 : 2 : end ]) && return (cudnnConvolutionDescriptor (cdims, x), x, identity)
19
+
20
+ # Maybe we should warn the user that this copies data, but other ML libs generally don't warn
21
+ sdims = NNlib. spatial_dims (cdims)
22
+
23
+ # Naive implementation, is there a faster way?
24
+ # How much we need to pad x manually: The absolute difference between pad_left and pad_right, pad_top
25
+ # and pad_bottom etc. respectively. We keep the sign here though because we use it below to figure out
26
+ # which side of x to pad.
27
+ pad_manual = pad[1 : 2 : 2 sdims] .- pad[2 : 2 : 2 sdims]
28
+ # How much we can let cudnn pad: The smallest padding amount between pad_left and pad_right, pad_top
29
+ # and pad_bottom etc. respectively
30
+ pad_cudnn = min .(pad[1 : 2 : 2 sdims], pad[2 : 2 : 2 sdims])
31
+
32
+
33
+ x_padded = similar (x, (size (x)[1 : sdims] .+ abs .(pad_manual)). .. , size (x)[end - 1 : end ]. .. )
34
+ # We could do the same yucky indexing stuff for the zeros too so we don't have to write zeros in the whole array.
35
+ # Not sure if it is worth it though...
36
+ fill! (x_padded, 0 )
37
+ # This is a bit yucky, but we are basically figuring out where in x_padded we shall insert x_inds
38
+ # Haven't benchmarked if this has any advantages over a more readable solution, e.g. writing dim by dim in a loop
39
+ x_inds = range .(1 .+ max .(0 , pad_manual), size (x)[1 : sdims] .- min .(0 , .- pad_manual))
40
+ x_padded[x_inds... , :, :] = x
41
+ return cudnnConvolutionDescriptor (cdims, x_padded, pad_cudnn), x_padded, _x -> _x[x_inds... ,:,:]
42
+ end
43
+
44
+ function cudnnConvolutionDescriptor (cdims:: DenseConvDims , x:: DenseCuArray{T} , pad = nnlibPadding (cdims)) where T
14
45
mode= (NNlib. flipkernel (cdims) ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION)
15
- cudnnConvolutionDescriptor (convdims (nnlibPadding (cdims), size (x),0 ),
46
+ cudnnConvolutionDescriptor (convdims (pad, size (x),0 ),
16
47
convdims (NNlib. stride (cdims),size (x),1 ),
17
48
convdims (NNlib. dilation (cdims),size (x),1 ),
18
49
mode,
@@ -30,7 +61,7 @@ function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims
30
61
if algo != - 1
31
62
@warn " algo option has been deprecated, the fastest algo is computed automatically" maxlog= 1
32
63
end
33
- d = cudnnConvolutionDescriptor (cdims, x)
64
+ d, x, _ = cudnnConvolutionDescriptorAndPaddedInput (cdims, x)
34
65
cudnnConvolutionForward! (y, w, x, d; alpha, beta, z= y)
35
66
end
36
67
@@ -43,7 +74,7 @@ function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{
43
74
if algo != - 1
44
75
@warn " The algo option has been deprecated, the fastest algo is computed automatically" maxlog= 1
45
76
end
46
- d = cudnnConvolutionDescriptor (cdims, x)
77
+ d, x, _ = cudnnConvolutionDescriptorAndPaddedInput (cdims, x)
47
78
# only relu and identity are supported by cudnnConvolutionForward!
48
79
activation = (σ == NNlib. relu ? CUDNN_ACTIVATION_RELU : CUDNN_ACTIVATION_IDENTITY)
49
80
cudnnConvolutionForward! (y, w, x, d; z, bias, activation, alpha, beta)
@@ -62,13 +93,13 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray
62
93
@warn " The algo option has been deprecated, the fastest algo is computed automatically" maxlog= 1
63
94
end
64
95
alpha, beta = scalingParameter (T,alpha), scalingParameter (T,beta);
96
+ convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput (cdims, dx)
65
97
xDesc, yDesc, wDesc = cudnnTensorDescriptor (dx), cudnnTensorDescriptor (dy), cudnnFilterDescriptor (w)
66
- convDesc = cudnnConvolutionDescriptor (cdims, dx)
67
98
p = cudnnConvolutionBwdDataAlgoPerf (wDesc, w, yDesc, dy, convDesc, xDesc, dx)
68
99
with_workspace (p. memory) do workspace
69
100
cudnnConvolutionBackwardData (handle (), alpha, wDesc, w, yDesc, dy, convDesc, p. algo, workspace, sizeof (workspace), beta, xDesc, dx)
70
101
end
71
- return dx
102
+ return depad (dx)
72
103
end
73
104
74
105
function ∇conv_filter! (dw:: DenseCuArray{T} , x:: DenseCuArray{T} , dy:: DenseCuArray{T} ,
@@ -80,8 +111,8 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr
80
111
@warn " The algo option has been deprecated, the fastest algo is computed automatically" maxlog= 1
81
112
end
82
113
alpha, beta = scalingParameter (T,alpha), scalingParameter (T,beta);
114
+ convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput (cdims, x)
83
115
xDesc, yDesc, wDesc = cudnnTensorDescriptor (x), cudnnTensorDescriptor (dy), cudnnFilterDescriptor (dw)
84
- convDesc = cudnnConvolutionDescriptor (cdims, x)
85
116
p = cudnnConvolutionBwdFilterAlgoPerf (xDesc, x, yDesc, dy, convDesc, wDesc, dw);
86
117
with_workspace (p. memory) do workspace
87
118
cudnnConvolutionBackwardFilter (handle (), alpha, xDesc, x, yDesc, dy, convDesc, p. algo, workspace, sizeof (workspace), beta, wDesc, dw);
0 commit comments