Skip to content

Commit f8b8811

Browse files
authored
fix_depthwise_conv_cudnn, test=develop (#20712) (#20727)
1 parent 5baf1b2 commit f8b8811

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

paddle/fluid/operators/conv_cudnn_op.cu

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,16 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
265265
algo = search::Find<T>(args, exhaustive_search, false, 0, ctx);
266266
workspace_size = search::GetWorkspaceSize(args, algo);
267267

268+
#if CUDNN_VERSION_MIN(7, 0, 1)
269+
// when groups > 1, SearchAlgorithm find algo is CUDNN_CONVOLUTION_\
270+
// FWD_ALGO_WINOGRAD_NONFUSED, but this kind of algorithm is unstable
271+
// in forward computation, so change the algorithm to CUDNN_CONVOLUTION_\
272+
// FWD_ALGO_IMPLICIT_GEMM manually.
273+
if (ctx.Attr<int>("groups") > 1) {
274+
algo = static_cast<cudnnConvolutionFwdAlgo_t>(0);
275+
}
276+
#endif
277+
268278
// ------------------- cudnn conv forward ---------------------
269279
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
270280
for (int i = 0; i < groups; i++) {
@@ -805,6 +815,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
805815
#if CUDNN_VERSION_MIN(7, 0, 1)
806816
iwo_group = 1;
807817
c_group = groups;
818+
groups = 1;
808819
#endif
809820
auto dtype = platform::CudnnDataType<T>::type;
810821

0 commit comments

Comments
 (0)