diff --git a/src/dnn/compat.jl b/src/dnn/compat.jl
index e81b96cc..ade757bb 100644
--- a/src/dnn/compat.jl
+++ b/src/dnn/compat.jl
@@ -1,11 +1,11 @@
 # Compatibility shims until users upgrade to new NNlib format
 function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}; pad=0, stride=1, flipkernel=0, dilation=1, kwargs...) where {T<:CUDNNFloat}
-    cdims = DenseConvDims(x, w; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
+    cdims = ConvDims(x, w; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
     return conv!(y, x, w, cdims; kwargs...)
 end
 
 function ∇conv_filter!(dw::CuArray{T}, dy::CuArray{T}, x::CuArray{T}; pad=0, stride=1, flipkernel=0, dilation=1, kwargs...) where {T<:CUDNNFloat}
-    cdims = DenseConvDims(x, dw; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
+    cdims = ConvDims(x, dw; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
     # NOTE!!! This compat shim re-arranges the argument order!
     return ∇conv_filter!(dw, x, dy, cdims; kwargs...)
 end
diff --git a/src/dnn/conv.jl b/src/dnn/conv.jl
index b0f3a9db..c0b46719 100644
--- a/src/dnn/conv.jl
+++ b/src/dnn/conv.jl
@@ -1,4 +1,4 @@
-using NNlib: DenseConvDims
+using NNlib: ConvDims
 
 # descriptor
 
@@ -28,7 +28,7 @@ end
 
 Base.cconvert(::Type{cudnnConvolutionMode_t}, x::Bool) = x ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION
 
-function ConvDesc(T, N, padding, stride, dilation, mode)
+function ConvDesc(T, N, padding, stride, dilation, mode, groupcount)
     cd = Ref{cudnnConvolutionDescriptor_t}()
     cudnnCreateConvolutionDescriptor(cd)
     if version() >= v"4"
@@ -38,18 +38,19 @@ function ConvDesc(T, N, padding, stride, dilation, mode)
     else
         cudnnSetConvolutionNdDescriptor(cd[],N,cdsize(padding,N),cdsize(stride,N),cdsize(dilation,N),mode)
     end
+    cudnnSetConvolutionGroupCount(cd[], Cint(groupcount))
     this = ConvDesc(cd[])
     finalizer(unsafe_free!, this)
     return this
 end
 
-function ConvDesc(T, cdims::DenseConvDims)
+function ConvDesc(T, cdims::ConvDims)
     pd = NNlib.padding(cdims)
     if !all(pd[1:2:end] .== pd[2:2:end])
         @warn("CuDNN does not support asymmetric padding; defaulting to symmetric choice")
     end
     return ConvDesc(T, NNlib.spatial_dims(cdims), pd[1:2:end], NNlib.stride(cdims),
-                    NNlib.dilation(cdims), NNlib.flipkernel(cdims))
+                    NNlib.dilation(cdims), NNlib.flipkernel(cdims), NNlib.group_count(cdims))
 end
 
 
@@ -68,7 +69,7 @@ function cudnnConvolutionBiasActivationForward(y::CuArray{T,N}, x::CuArray{T,N},
 end
 
 function cudnnConvolutionForward(y::CuArray{T,N}, x::CuArray{T,N}, w::CuArray{T,N},
-                                 cdims::DenseConvDims; algo=0, alpha=1, beta=0) where {T,N}
+                                 cdims::ConvDims; algo=0, alpha=1, beta=0) where {T,N}
     @workspace size=@argout(
         cudnnGetConvolutionForwardWorkspaceSize(
             handle(), TensorDesc(x),
@@ -86,7 +87,7 @@ function cudnnConvolutionForward(y::CuArray{T,N}, x::CuArray{T,N}, w::CuArray{T,
 end
 
 function cudnnConvolutionBackwardData(dx::CuArray{T,N}, w::CuArray{T,N}, dy::CuArray{T,N},
-                                      cdims::DenseConvDims; algo=0, alpha=1, beta=0) where {T,N}
+                                      cdims::ConvDims; algo=0, alpha=1, beta=0) where {T,N}
     @workspace size=@argout(
         cudnnGetConvolutionBackwardDataWorkspaceSize(
             handle(), FilterDesc(w),
@@ -105,7 +106,7 @@ function cudnnConvolutionBackwardData(dx::CuArray{T,N}, w::CuArray{T,N}, dy::CuA
 end
 
 function cudnnConvolutionBackwardFilter(dw::CuArray{T,N}, x::CuArray{T,N}, dy::CuArray{T,N},
-                                        cdims::DenseConvDims; algo=0, alpha=1, beta=0) where {T,N}
+                                        cdims::ConvDims; algo=0, alpha=1, beta=0) where {T,N}
     @workspace size=@argout(
         cudnnGetConvolutionBackwardFilterWorkspaceSize(
             handle(), TensorDesc(x),
diff --git a/src/dnn/nnlib.jl b/src/dnn/nnlib.jl
index 65f8c2cc..7b2d3e4a 100644
--- a/src/dnn/nnlib.jl
+++ b/src/dnn/nnlib.jl
@@ -41,7 +41,7 @@ end
 
 # Convolution
 
-function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}, cdims::DenseConvDims;
+function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}, cdims::ConvDims;
                alpha=1, algo=0) where T<:CUDNNFloat
     if version() < v"6"
         all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
@@ -51,7 +51,7 @@ function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}, cdims::DenseConvDims
 end
 
 function ∇conv_filter!(dw::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
-                       cdims::DenseConvDims; alpha=1, algo=0) where T<:CUDNNFloat
+                       cdims::ConvDims; alpha=1, algo=0) where T<:CUDNNFloat
     if version() < v"6"
         all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
     end
@@ -60,7 +60,7 @@ function ∇conv_filter!(dw::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
 end
 
 function ∇conv_data!(dx::CuArray{T}, dy::CuArray{T}, w::CuArray{T},
-                     cdims::DenseConvDims; alpha=1, algo=0) where T<:CUDNNFloat
+                     cdims::ConvDims; alpha=1, algo=0) where T<:CUDNNFloat
     if version() < v"6"
         all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
     end
diff --git a/test/dnn.jl b/test/dnn.jl
index cbed6394..b7c15d89 100644
--- a/test/dnn.jl
+++ b/test/dnn.jl
@@ -17,7 +17,7 @@ else
              softmax, ∇softmax, logsoftmax, ∇logsoftmax
     a, b, c = rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4), rand(Float64, 9, 9, 4, 1)
     da, db, dc = CuArray(a), CuArray(b), CuArray(c)
-    cdims = DenseConvDims(a, b)
+    cdims = ConvDims(a, b)
     @test NNlib.conv(a, b, cdims) ≈ collect(NNlib.conv(da, db, cdims))
     @test ∇conv_data(c, b, cdims) ≈ collect(∇conv_data(dc, db, cdims))
     @test ∇conv_filter(a, c, cdims) ≈ collect(∇conv_filter(da, dc, cdims))
@@ -35,7 +35,7 @@ else
     algos = (1, 0, 1, 1,)
 
     for (opts, algo) in zip(options, algos)
-        cdims = ConvDims(x, w; opts...)
+        cdims = ConvDims(x, w; opts...)
         y = NNlib.conv(x, w, cdims)
 
         # Test that basic convolution is equivalent across GPU/CPU
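
Note (not part of the patch): a minimal sketch of how the grouped-convolution path added above might be exercised, assuming the CuArrays.jl layout these file paths suggest. The `groups` keyword on the ConvDims constructor is an assumption made for illustration only; the exact NNlib keyword may differ. NNlib.group_count and cudnnSetConvolutionGroupCount are the calls actually introduced in the diff.

    using CuArrays, NNlib

    # Hypothetical example: a 2-group convolution dispatched through the cuDNN path.
    # `groups=2` is an assumed keyword, shown only to illustrate the new code path.
    x = CuArray(rand(Float32, 28, 28, 4, 1))  # 4 input channels
    w = CuArray(rand(Float32, 3, 3, 2, 8))    # 2 input channels per group, 8 output channels
    cdims = ConvDims(x, w; groups=2)          # NNlib.group_count(cdims) would report 2
    y = NNlib.conv(x, w, cdims)               # cudnnConvolutionForward builds a ConvDesc that
                                              # now calls cudnnSetConvolutionGroupCount(cd, 2)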