Skip to content

Commit 6ac54c0

Browse files
authored
Merge pull request #34 from DrChainsaw/asymmetric_conv_pad
Add support for asymmetric padding for conv layers
2 parents 451901c + fa5ae79 commit 6ac54c0

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

ext/NNlibCUDA/src/cudnn/conv.jl

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,42 @@ using CUDA.CUDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
1010

1111
const CUDNNFloat = Union{Float16,Float32,Float64}
1212

13-
function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}) where T
13+
function cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::DenseCuArray{T}) where T
14+
# The main purpose of this function is to catch asymmetric padding which cudnn does not support
15+
# If we find asymmetric padding we'll make a copy of x which is manually padded so that we can
16+
# call cudnn with symmetric padding.
17+
pad = NNlib.padding(cdims)
18+
sdims = NNlib.spatial_dims(cdims)
19+
all(i -> pad[i] .== pad[i+1], 1:2:2sdims) && return (cudnnConvolutionDescriptor(cdims, x), x, identity)
20+
21+
# Naive implementation, is there a faster way?
22+
# How much we need to pad x manually: The absolute difference between pad_left and pad_right, pad_top
23+
# and pad_bottom etc. respectively. We keep the sign here though because we use it below to figure out
24+
# which side of x to pad. Oh, and we use a CartesianIndex as we will mainly use this to index in x
25+
pad_manual = CartesianIndex(ntuple(i -> i > sdims ? 0 : pad[2(i-1)+1] - pad[2(i-1)+2], ndims(x)))
26+
27+
# How much we can let cudnn pad: The smallest padding amount between pad_left and pad_right, pad_top
28+
# and pad_bottom etc. respectively
29+
pad_cudnn = ntuple(i -> min(pad[2(i-1)+1], pad[2(i-1)+2]), sdims)
30+
31+
x_padded_size = ntuple(i -> i <= sdims ? size(x, i) + abs(pad_manual[i]) : size(x ,i), ndims(x))
32+
x_padded = similar(x, x_padded_size)
33+
fill!(x_padded, 0)
34+
# This is a bit yucky, but we are basically figuring out where in x_padded we shall insert x
35+
# Haven't benchmarked if this has any advantages over a more readable solution, e.g. writing dim
36+
# by dim to an array in a loop
37+
xIs = CartesianIndices(x)
38+
xI_first = first(xIs)
39+
xI_last = last(xIs)
40+
xIs_pad = max(xI_first, xI_first + pad_manual) : max(xI_last, xI_last + pad_manual)
41+
x_padded[xIs_pad] = x
42+
43+
return cudnnConvolutionDescriptor(cdims, x_padded, pad_cudnn), x_padded, _x -> _x[xIs_pad]
44+
end
45+
46+
function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}, pad = nnlibPadding(cdims)) where T
1447
mode=(NNlib.flipkernel(cdims) ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION)
15-
cudnnConvolutionDescriptor(convdims(nnlibPadding(cdims),size(x),0),
48+
cudnnConvolutionDescriptor(convdims(pad, size(x),0),
1649
convdims(NNlib.stride(cdims),size(x),1),
1750
convdims(NNlib.dilation(cdims),size(x),1),
1851
mode,
@@ -30,7 +63,7 @@ function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims
3063
if algo != -1
3164
@warn "algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
3265
end
33-
d = cudnnConvolutionDescriptor(cdims, x)
66+
d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
3467
cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y)
3568
end
3669

@@ -43,7 +76,7 @@ function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{
4376
if algo != -1
4477
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
4578
end
46-
d = cudnnConvolutionDescriptor(cdims, x)
79+
d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
4780
# only relu and identity are supported by cudnnConvolutionForward!
4881
activation === NNlib.relu ? CUDNN_ACTIVATION_RELU : CUDNN_ACTIVATION_IDENTITY)
4982
cudnnConvolutionForward!(y, w, x, d; z, bias, activation, alpha, beta)
@@ -62,13 +95,13 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray
6295
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
6396
end
6497
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
98+
convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput(cdims, dx)
6599
xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)
66-
convDesc = cudnnConvolutionDescriptor(cdims, dx)
67100
p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx)
68101
with_workspace(p.memory) do workspace
69102
cudnnConvolutionBackwardData(handle(), alpha, wDesc, w, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, xDesc, dx)
70103
end
71-
return dx
104+
return depad(dx)
72105
end
73106

74107
function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
@@ -80,8 +113,8 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr
80113
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
81114
end
82115
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
116+
convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
83117
xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)
84-
convDesc = cudnnConvolutionDescriptor(cdims, x)
85118
p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw);
86119
with_workspace(p.memory) do workspace
87120
cudnnConvolutionBackwardFilter(handle(), alpha, xDesc, x, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, wDesc, dw);

ext/NNlibCUDA/test/conv.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ using NNlib: DenseConvDims
1212
options = Dict{Any, Any}.((
1313
(), (:dilation => 2), (:flipkernel => true), (:stride => 2),
1414
(:padding => 1),
15+
(:padding => (1,0)),
16+
(:padding => (0,1)),
17+
(:padding => (2,3)),
1518
))
1619
C_in_ = 3
1720
C_out = 4
@@ -26,6 +29,14 @@ using NNlib: DenseConvDims
2629

2730
for opts in options
2831
opts[:groups] = groups
32+
33+
if :padding in keys(opts)
34+
padding = opts[:padding]
35+
if 1 < length(padding) && length(padding) != 2num_spatial_dims
36+
opts[:padding] = ntuple(i -> padding[mod1(i,2)] .+ 2div(i-1,2), 2num_spatial_dims)
37+
end
38+
end
39+
2940
cdims = DenseConvDims(x, w; opts...)
3041
y = NNlib.conv(x, w, cdims)
3142

@@ -44,5 +55,4 @@ using NNlib: DenseConvDims
4455
gputest((w, x, y) -> NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=2.0), w, x, y, checkgrad=false) # TODO
4556
end
4657
end
47-
4858
end

0 commit comments

Comments
 (0)