File tree Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Original file line number Diff line number Diff line change @@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
52
52
ScopedTensorDescriptor input_desc;
53
53
ScopedTensorDescriptor output_desc;
54
54
ScopedPoolingDescriptor pool_desc;
55
- DataLayout layout = DataLayout::kNCHW ;
55
+ DataLayout layout;
56
+
57
+ if (strides.size () == 2U ) {
58
+ layout = DataLayout::kNCHW ;
59
+ } else {
60
+ layout = DataLayout::kNCDHW ;
61
+ }
56
62
57
63
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor <T>(
58
64
layout, framework::vectorize2int (input->dims ()));
@@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
112
118
ScopedTensorDescriptor input_desc;
113
119
ScopedTensorDescriptor output_desc;
114
120
ScopedPoolingDescriptor pool_desc;
115
- DataLayout layout = DataLayout::kNCHW ;
121
+ DataLayout layout;
122
+
123
+ if (strides.size () == 2U ) {
124
+ layout = DataLayout::kNCHW ;
125
+ } else {
126
+ layout = DataLayout::kNCDHW ;
127
+ }
116
128
117
129
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor <T>(
118
130
layout, framework::vectorize2int (input->dims ()));
You can’t perform that action at this time.
0 commit comments