@@ -10,9 +10,42 @@ using CUDA.CUDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
 
 const CUDNNFloat = Union{Float16,Float32,Float64}
 
-function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}) where T
+function cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::DenseCuArray{T}) where T
+    # The main purpose of this function is to catch asymmetric padding, which cudnn does not support.
+    # If we find asymmetric padding we'll make a copy of x which is manually padded so that we can
+    # call cudnn with symmetric padding.
+    pad = NNlib.padding(cdims)
+    sdims = NNlib.spatial_dims(cdims)
+    all(i -> pad[i] .== pad[i+1], 1:2:2sdims) && return (cudnnConvolutionDescriptor(cdims, x), x, identity)
+
+    # Naive implementation, is there a faster way?
+    # How much we need to pad x manually: the absolute difference between pad_left and pad_right, pad_top
+    # and pad_bottom etc. respectively. We keep the sign here though because we use it below to figure out
+    # which side of x to pad. Oh, and we use a CartesianIndex as we will mainly use this to index into x.
+    pad_manual = CartesianIndex(ntuple(i -> i > sdims ? 0 : pad[2(i-1)+1] - pad[2(i-1)+2], ndims(x)))
+
+    # How much we can let cudnn pad: the smallest padding amount between pad_left and pad_right, pad_top
+    # and pad_bottom etc. respectively.
+    pad_cudnn = ntuple(i -> min(pad[2(i-1)+1], pad[2(i-1)+2]), sdims)
+
+    x_padded_size = ntuple(i -> i <= sdims ? size(x, i) + abs(pad_manual[i]) : size(x, i), ndims(x))
+    x_padded = similar(x, x_padded_size)
+    fill!(x_padded, 0)
+    # This is a bit yucky, but we are basically figuring out where in x_padded we shall insert x.
+    # Haven't benchmarked if this has any advantages over a more readable solution, e.g. writing dim
+    # by dim to an array in a loop.
+    xIs = CartesianIndices(x)
+    xI_first = first(xIs)
+    xI_last = last(xIs)
+    xIs_pad = max(xI_first, xI_first + pad_manual):max(xI_last, xI_last + pad_manual)
+    x_padded[xIs_pad] = x
+
+    return cudnnConvolutionDescriptor(cdims, x_padded, pad_cudnn), x_padded, _x -> _x[xIs_pad]
+end
+
+function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}, pad = nnlibPadding(cdims)) where T
     mode=(NNlib.flipkernel(cdims) ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION)
-    cudnnConvolutionDescriptor(convdims(nnlibPadding(cdims), size(x),0),
+    cudnnConvolutionDescriptor(convdims(pad, size(x),0),
                                convdims(NNlib.stride(cdims),size(x),1),
                                convdims(NNlib.dilation(cdims),size(x),1),
                                mode,
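
To make the padding split concrete, here is a small CPU-only sketch (not part of the diff; the 2-D layout, sizes and padding values are made up for illustration) that runs the same arithmetic for a hypothetical convolution with pad = (left, right, top, bottom) = (2, 1, 0, 3) on a 5×5×1×1 input:

# Hypothetical example values; any asymmetric padding behaves analogously.
pad   = (2, 1, 0, 3)   # (pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi)
sdims = 2              # number of spatial dimensions
nd    = 4              # W × H × C × N layout, as NNlib uses

# Signed surplus per spatial dimension that has to be padded manually.
pad_manual = CartesianIndex(ntuple(i -> i > sdims ? 0 : pad[2(i-1)+1] - pad[2(i-1)+2], nd))
# Symmetric part that cudnn can be asked to apply itself.
pad_cudnn = ntuple(i -> min(pad[2(i-1)+1], pad[2(i-1)+2]), sdims)

x = reshape(collect(Float32, 1:25), 5, 5, 1, 1)
x_padded = zeros(Float32, ntuple(i -> i <= sdims ? size(x, i) + abs(pad_manual[i]) : size(x, i), nd))
xIs = CartesianIndices(x)
xIs_pad = max(first(xIs), first(xIs) + pad_manual):max(last(xIs), last(xIs) + pad_manual)
x_padded[xIs_pad] = x

pad_manual      # CartesianIndex(1, -3, 0, 0): one extra element at the front of dim 1, three at the end of dim 2
pad_cudnn       # (1, 0): the symmetric padding handed to cudnn
size(x_padded)  # (6, 8, 1, 1)
xIs_pad         # CartesianIndices((2:6, 1:5, 1:1, 1:1)): where x lands inside x_padded
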
@@ -30,7 +63,7 @@ function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims
     if algo != -1
         @warn "algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
     end
-    d = cudnnConvolutionDescriptor(cdims, x)
+    d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
     cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y)
 end
 
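
As a hedged usage sketch of what this enables at the NNlib level (the names and sizes are made up, and a CUDA-capable GPU with these NNlibCUDA methods loaded is assumed): callers no longer need to reject or work around asymmetric padding, because conv! now pads a copy of the input itself and hands cudnn only the symmetric part.

using CUDA, NNlib

x = CUDA.rand(Float32, 7, 7, 3, 2)                  # W × H × C × N
w = CUDA.rand(Float32, 3, 3, 3, 4)
cdims = DenseConvDims(x, w; padding = (2, 1, 0, 3)) # asymmetric: left ≠ right, top ≠ bottom
y = NNlib.conv(x, w, cdims)                         # ends up in the conv! method above
size(y)  # (8, 8, 4, 2): spatial size (7+2+1-3+1, 7+0+3-3+1)
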
@@ -43,7 +76,7 @@ function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{
     if algo != -1
         @warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
     end
-    d = cudnnConvolutionDescriptor(cdims, x)
+    d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
     # only relu and identity are supported by cudnnConvolutionForward!
     activation = (σ == NNlib.relu ? CUDNN_ACTIVATION_RELU : CUDNN_ACTIVATION_IDENTITY)
     cudnnConvolutionForward!(y, w, x, d; z, bias, activation, alpha, beta)
@@ -62,13 +95,13 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray
         @warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
     end
     alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
+    convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput(cdims, dx)
     xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)
-    convDesc = cudnnConvolutionDescriptor(cdims, dx)
     p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx)
     with_workspace(p.memory) do workspace
         cudnnConvolutionBackwardData(handle(), alpha, wDesc, w, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, xDesc, dx)
     end
-    return dx
+    return depad(dx)
 end
 
 function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
@@ -80,8 +113,8 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr
         @warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
     end
     alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
+    convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
     xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)
-    convDesc = cudnnConvolutionDescriptor(cdims, x)
     p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw);
     with_workspace(p.memory) do workspace
         cudnnConvolutionBackwardFilter(handle(), alpha, xDesc, x, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, wDesc, dw);
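
Since ∇conv_data! now computes into the padded buffer and crops the result back via depad, while ∇conv_filter! only consumes the padded copy of x, one way to sanity-check the change is to compare the GPU gradients against NNlib's CPU reference. A minimal sketch, assuming a CUDA-capable GPU and made-up sizes:

using CUDA, NNlib

x = rand(Float32, 7, 7, 3, 2)
w = rand(Float32, 3, 3, 3, 4)
cdims = DenseConvDims(x, w; padding = (2, 1, 0, 3))
dy = rand(Float32, 8, 8, 4, 2)   # spatial output size: (7+2+1-3+1, 7+0+3-3+1) = (8, 8)

# Data gradient: written into the padded dx on the GPU, then cropped by depad back to size(x).
dx_gpu = NNlib.∇conv_data(cu(dy), cu(w), cdims)
dx_cpu = NNlib.∇conv_data(dy, w, cdims)
isapprox(Array(dx_gpu), dx_cpu; rtol = 1f-4)   # expected to hold

# Filter gradient: the padded copy of x is only an input here, so nothing needs cropping.
dw_gpu = NNlib.∇conv_filter(cu(x), cu(dy), cdims)
dw_cpu = NNlib.∇conv_filter(x, dy, cdims)
isapprox(Array(dw_gpu), dw_cpu; rtol = 1f-4)   # expected to hold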