Skip to content

Commit 2ec07fc

Browse files
committed
Add support for asymmetric padding for conv layers
1 parent 451901c commit 2ec07fc

File tree

2 files changed

+49
-8
lines changed

2 files changed

+49
-8
lines changed

ext/NNlibCUDA/src/cudnn/conv.jl

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,40 @@ using CUDA.CUDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
1010

1111
const CUDNNFloat = Union{Float16,Float32,Float64}
1212

13-
function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}) where T
13+
function cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::DenseCuArray{T}) where T
    # cuDNN only supports symmetric padding, so this function detects asymmetric
    # padding in `cdims`. When found, it builds a manually padded copy of `x` so
    # the cuDNN call itself only needs the symmetric part. Returns a tuple
    # `(descriptor, input, depad)` where `depad` undoes the manual padding
    # (used by the backward-data pass); it is `identity` in the symmetric case.
    padvec = collect(NNlib.padding(cdims)) # an array keeps things more type stable
    if all(padvec[1:2:end] .== padvec[2:2:end])
        return (cudnnConvolutionDescriptor(cdims, x), x, identity)
    end

    # NOTE: this path copies the input. Other ML libraries generally do not warn
    # about this either, so we stay silent as well.
    nspatial = NNlib.spatial_dims(cdims)

    # Naive implementation; there may be a faster way.
    # Signed difference between begin-side and end-side padding per spatial
    # dimension: this is the amount we must pad manually. The sign is kept so we
    # know which side of `x` receives the extra elements below.
    extra = padvec[1:2:2nspatial] .- padvec[2:2:2nspatial]
    # The symmetric part (the smaller of the two per-dimension amounts) can be
    # left for cuDNN to handle.
    sympad = min.(padvec[1:2:2nspatial], padvec[2:2:2nspatial])

    spatial_sz = size(x)[1:nspatial]
    xpad = similar(x, (spatial_sz .+ abs.(extra))..., size(x)[end-1:end]...)
    # We could restrict the zero-fill to the border region only, so the interior
    # isn't written twice — unclear whether that would be worth the complexity.
    fill!(xpad, 0)
    # Destination indices of `x` inside `xpad`: shift the start where the
    # begin-side padding dominates (`extra > 0`), and extend the end where the
    # end-side padding dominates (`extra < 0`).
    # Haven't benchmarked whether this broadcast form beats a per-dimension loop.
    dest = range.(1 .+ max.(0, extra), spatial_sz .- min.(0, .-extra))
    xpad[dest..., :, :] = x
    return cudnnConvolutionDescriptor(cdims, xpad, sympad), xpad, _x -> _x[dest..., :, :]
end
43+
44+
function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}, pad = nnlibPadding(cdims)) where T
1445
mode=(NNlib.flipkernel(cdims) ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION)
15-
cudnnConvolutionDescriptor(convdims(nnlibPadding(cdims),size(x),0),
46+
cudnnConvolutionDescriptor(convdims(pad, size(x),0),
1647
convdims(NNlib.stride(cdims),size(x),1),
1748
convdims(NNlib.dilation(cdims),size(x),1),
1849
mode,
@@ -30,7 +61,7 @@ function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims
3061
if algo != -1
3162
@warn "algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
3263
end
33-
d = cudnnConvolutionDescriptor(cdims, x)
64+
d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
3465
cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y)
3566
end
3667

@@ -43,7 +74,7 @@ function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{
4374
if algo != -1
4475
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
4576
end
46-
d = cudnnConvolutionDescriptor(cdims, x)
77+
d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
4778
# only relu and identity are supported by cudnnConvolutionForward!
4879
activation === NNlib.relu ? CUDNN_ACTIVATION_RELU : CUDNN_ACTIVATION_IDENTITY)
4980
cudnnConvolutionForward!(y, w, x, d; z, bias, activation, alpha, beta)
@@ -62,13 +93,13 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray
6293
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
6394
end
6495
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
96+
convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput(cdims, dx)
6597
xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)
66-
convDesc = cudnnConvolutionDescriptor(cdims, dx)
6798
p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx)
6899
with_workspace(p.memory) do workspace
69100
cudnnConvolutionBackwardData(handle(), alpha, wDesc, w, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, xDesc, dx)
70101
end
71-
return dx
102+
return depad(dx)
72103
end
73104

74105
function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
@@ -80,8 +111,8 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr
80111
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
81112
end
82113
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
114+
convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
83115
xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)
84-
convDesc = cudnnConvolutionDescriptor(cdims, x)
85116
p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw);
86117
with_workspace(p.memory) do workspace
87118
cudnnConvolutionBackwardFilter(handle(), alpha, xDesc, x, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, wDesc, dw);

ext/NNlibCUDA/test/conv.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ using NNlib: DenseConvDims
1212
options = Dict{Any, Any}.((
1313
(), (:dilation => 2), (:flipkernel => true), (:stride => 2),
1414
(:padding => 1),
15+
(:padding => (1,0)),
16+
(:padding => (0,1)),
17+
(:padding => (2,3)),
1518
))
1619
C_in_ = 3
1720
C_out = 4
@@ -26,6 +29,14 @@ using NNlib: DenseConvDims
2629

2730
for opts in options
2831
opts[:groups] = groups
32+
33+
if :padding in keys(opts)
34+
padding = opts[:padding]
35+
if 1 < length(padding) && length(padding) != 2num_spatial_dims
36+
opts[:padding] = ntuple(i -> padding[mod1(i,2)] .+ 2div(i-1,2), 2num_spatial_dims)
37+
end
38+
end
39+
2940
cdims = DenseConvDims(x, w; opts...)
3041
y = NNlib.conv(x, w, cdims)
3142

@@ -44,5 +55,4 @@ using NNlib: DenseConvDims
4455
gputest((w, x, y) -> NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=2.0), w, x, y, checkgrad=false) # TODO
4556
end
4657
end
47-
4858
end

0 commit comments

Comments
 (0)