Commit f3edb3b

Use ntuple instead of slicing/broadcasting to improve type stability
1 parent 2ec07fc commit f3edb3b
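
The type-stability argument behind the commit message: collecting the padding tuple into a Vector and slicing it hides the element count from the compiler, while building the result with ntuple keeps it a tuple whose length can be inferred when the number of spatial dimensions is known at compile time. A minimal standalone sketch of the difference (the helper names and the Val-wrapped dimension count are illustrative and not from this commit, which passes sdims as a plain integer obtained from NNlib.spatial_dims):

using Test: @inferred

# Old style: collect the padding tuple into a Vector and slice it; the result
# is a Vector{Int} whose length inference cannot track.
pad_manual_slice(pad, sdims) = collect(pad)[1:2:2sdims] .- collect(pad)[2:2:2sdims]

# New style: build the result with ntuple; when the number of spatial dims is
# a compile-time constant (forced here via Val), the result is an NTuple whose
# length the compiler sees.
pad_manual_ntuple(pad, ::Val{N}) where {N} = ntuple(i -> pad[2(i-1)+1] - pad[2(i-1)+2], Val(N))

pad = (1, 2, 0, 3)                        # asymmetric (left, right, top, bottom) padding
pad_manual_slice(pad, 2)                  # Vector{Int}, length unknown to inference
@inferred pad_manual_ntuple(pad, Val(2))  # NTuple{2, Int}, fully inferred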

File tree

1 file changed (+9 -12 lines)


ext/NNlibCUDA/src/cudnn/conv.jl

Lines changed: 9 additions & 12 deletions
@@ -14,29 +14,26 @@ function cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::Dense
     # The main purpose of this function is to catch asymmetric padding which cudnn does not support
     # If we find asymmetric padding we'll make a copy of x which is manually padded so that we can
     # call cudnn with symmetric padding.
-    pad = collect(NNlib.padding(cdims)) # work with an array to make things more type stable
-    all(pad[1:2:end] .== pad[2:2:end]) && return (cudnnConvolutionDescriptor(cdims, x), x, identity)
-
-    # Maybe we should warn the user that this copies data, but other ML libs generally don't warn
+    pad = NNlib.padding(cdims)
     sdims = NNlib.spatial_dims(cdims)
-
+    all(i -> pad[i] .== pad[i+1], 1:2:2sdims) && return (cudnnConvolutionDescriptor(cdims, x), x, identity)
+
     # Naive implementation, is there a faster way?
     # How much we need to pad x manually: The absolute difference between pad_left and pad_right, pad_top
     # and pad_bottom etc. respectively. We keep the sign here though because we use it below to figure out
     # which side of x to pad.
-    pad_manual = pad[1:2:2sdims] .- pad[2:2:2sdims]
+    pad_manual = ntuple(i -> pad[2(i-1)+1] - pad[2(i-1)+2], sdims)
     # How much we can let cudnn pad: The smallest padding amount between pad_left and pad_right, pad_top
     # and pad_bottom etc. respectively
-    pad_cudnn = min.(pad[1:2:2sdims], pad[2:2:2sdims])
+    pad_cudnn = ntuple(i -> min(pad[2(i-1)+1], pad[2(i-1)+2]), sdims)
 
+    x_padded_size = ntuple(i -> i <= sdims ? size(x, i) + abs(pad_manual[i]) : size(x, i), ndims(x))
 
-    x_padded = similar(x, (size(x)[1:sdims] .+ abs.(pad_manual))..., size(x)[end-1:end]...)
-    # We could do the same yucky indexing stuff for the zeros too so we don't have to write zeros in the whole array.
-    # Not sure if it is worth it though...
+    x_padded = similar(x, x_padded_size)
     fill!(x_padded, 0)
     # This is a bit yucky, but we are basically figuring out where in x_padded we shall insert x_inds
-    # Haven't benchmarked if this has any advantages over a more readable solution, e.g. writing dim by dim in a loop
-    x_inds = range.(1 .+ max.(0, pad_manual), size(x)[1:sdims] .- min.(0, .-pad_manual))
+    # Haven't benchmarked if this has any advantages over a more readable solution, e.g. writing dim by dim in a loop
+    x_inds = ntuple(i -> range(1 + max(0, pad_manual[i]), size(x,i) - min(0, -pad_manual[i])), sdims)
     x_padded[x_inds..., :, :] = x
     return cudnnConvolutionDescriptor(cdims, x_padded, pad_cudnn), x_padded, _x -> _x[x_inds...,:,:]
 end
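
For reference, a small CPU-only walk-through of how the changed lines split asymmetric padding into a manually padded copy plus a symmetric remainder for cudnn. The padding values and the plain Float32 array are made up for illustration; the real function operates on a CUDA array and also returns a convolution descriptor as above:

pad   = (1, 2, 0, 0)   # (left, right, top, bottom): asymmetric along the first spatial dim
sdims = 2
x = reshape(Float32.(1:9), 3, 3, 1, 1)   # W x H x C x N

# Signed left/right difference (manual part) and the symmetric part left to cudnn
pad_manual = ntuple(i -> pad[2(i-1)+1] - pad[2(i-1)+2], sdims)       # (-1, 0)
pad_cudnn  = ntuple(i -> min(pad[2(i-1)+1], pad[2(i-1)+2]), sdims)   # (1, 0)

# Enlarge only the spatial dims by the absolute manual padding
x_padded_size = ntuple(i -> i <= sdims ? size(x, i) + abs(pad_manual[i]) : size(x, i), ndims(x))
x_padded = zeros(Float32, x_padded_size)   # 4 x 3 x 1 x 1

# x lands at 1:3 in both spatial dims, so the extra zero row sits on the
# "right" side, where the requested padding was larger
x_inds = ntuple(i -> range(1 + max(0, pad_manual[i]), size(x, i) - min(0, -pad_manual[i])), sdims)
x_padded[x_inds..., :, :] = x

# cudnn then pads symmetrically by pad_cudnn = (1, 0): total left padding is 1,
# total right padding is 1 (manual) + 1 (cudnn) = 2, matching the requested (1, 2)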

0 commit comments
