Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 47fdd49

Browse files
committed
Using groupcount parameter for depthwise and groupwise convolutions
1 parent b38db15 commit 47fdd49

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/dnn/conv.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828

2929
Base.cconvert(::Type{cudnnConvolutionMode_t}, x::Bool) = x ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION
3030

31-
function ConvDesc(T, N, padding, stride, dilation, mode)
31+
function ConvDesc(T, N, padding, stride, dilation, mode, groupcount)
3232
cd = Ref{cudnnConvolutionDescriptor_t}()
3333
cudnnCreateConvolutionDescriptor(cd)
3434
if version() >= v"4"
@@ -38,6 +38,7 @@ function ConvDesc(T, N, padding, stride, dilation, mode)
3838
else
3939
cudnnSetConvolutionNdDescriptor(cd[],N,cdsize(padding,N),cdsize(stride,N),cdsize(dilation,N),mode)
4040
end
41+
cudnnSetConvolutionGroupCount(cd[], Cint(groupcount))
4142
this = ConvDesc(cd[])
4243
finalizer(unsafe_free!, this)
4344
return this
@@ -49,7 +50,7 @@ function ConvDesc(T, cdims::DenseConvDims)
4950
@warn("CuDNN does not support asymmetric padding; defaulting to symmetric choice")
5051
end
5152
return ConvDesc(T, NNlib.spatial_dims(cdims), pd[1:2:end], NNlib.stride(cdims),
52-
NNlib.dilation(cdims), NNlib.flipkernel(cdims))
53+
NNlib.dilation(cdims), NNlib.flipkernel(cdims), NNlib.group_count(cdims))
5354
end
5455

5556

0 commit comments

Comments
 (0)