Skip to content

Commit 7c2fd61

Browse files
committed
fix data layout
1 parent e825a49 commit 7c2fd61

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

paddle/operators/pool_cudnn_op.cu

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
5252
ScopedTensorDescriptor input_desc;
5353
ScopedTensorDescriptor output_desc;
5454
ScopedPoolingDescriptor pool_desc;
55-
DataLayout layout = DataLayout::kNCHW;
55+
DataLayout layout;
56+
57+
if (strides.size() == 2U) {
58+
layout = DataLayout::kNCHW;
59+
} else {
60+
layout = DataLayout::kNCDHW;
61+
}
5662

5763
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
5864
layout, framework::vectorize2int(input->dims()));
@@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
112118
ScopedTensorDescriptor input_desc;
113119
ScopedTensorDescriptor output_desc;
114120
ScopedPoolingDescriptor pool_desc;
115-
DataLayout layout = DataLayout::kNCHW;
121+
DataLayout layout;
122+
123+
if (strides.size() == 2U) {
124+
layout = DataLayout::kNCHW;
125+
} else {
126+
layout = DataLayout::kNCDHW;
127+
}
116128

117129
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
118130
layout, framework::vectorize2int(input->dims()));

0 commit comments

Comments
 (0)