
Commit 74912c7

fix data layout
1 parent c5330fa

2 files changed (+20 −13 lines)

paddle/operators/conv_transpose_cudnn_op.cu

Lines changed: 13 additions & 7 deletions
@@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
     ScopedTensorDescriptor output_desc;
     ScopedFilterDescriptor filter_desc;
     ScopedConvolutionDescriptor conv_desc;
-    DataLayout layout = DataLayout::kNCHW;
+    DataLayout layout;
+
+    if (strides.size() == 2U) {
+      layout = DataLayout::kNCHW;
+    } else {
+      layout = DataLayout::kNCDHW;
+    }
 
-    // N, M, H, W
+    // (N, M, H, W) or (N, M, D, H, W)
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize2int(input->dims()));
-    // N, C, O_h, O_w
+    // (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
         layout, framework::vectorize2int(output->dims()));
-    // M, C, K_h, K_w
+    // (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
         layout, framework::vectorize2int(filter->dims()));
     cudnnConvolutionDescriptor_t cudnn_conv_desc =
@@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
     ScopedConvolutionDescriptor conv_desc;
     DataLayout layout = DataLayout::kNCHW;
 
-    // Input: (N, M, H, W)
+    // Input: (N, M, H, W) or (N, M, D, H, W)
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize2int(input->dims()));
-    // Output: (N, C, O_H, O_W)
+    // Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
         layout, framework::vectorize2int(output_grad->dims()));
-    // Filter: (M, C, K_H, K_W)
+    // Filter: (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
         layout, framework::vectorize2int(filter->dims()));
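
The forward kernel now picks the layout from the rank of the stride vector: two spatial strides mean a 2-D transposed convolution in NCHW, anything else is treated as the 3-D NCDHW case. A minimal standalone sketch of that selection (SelectLayout is a hypothetical name, not Paddle's API; the enum mirrors the one added to cudnn_helper.h below):

#include <cassert>
#include <vector>

// Mirrors paddle::platform::DataLayout after this commit.
enum class DataLayout { kNHWC, kNCHW, kNCDHW, kNCHW_VECT_C };

// Hypothetical helper: the stride vector holds one entry per spatial
// dimension, so its size distinguishes conv2d_transpose (NCHW) from
// conv3d_transpose (NCDHW), exactly as the kernel's if/else does.
DataLayout SelectLayout(const std::vector<int>& strides) {
  return strides.size() == 2U ? DataLayout::kNCHW : DataLayout::kNCDHW;
}

int main() {
  assert(SelectLayout({1, 1}) == DataLayout::kNCHW);      // (N, M, H, W)
  assert(SelectLayout({1, 1, 1}) == DataLayout::kNCDHW);  // (N, M, D, H, W)
  return 0;
}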

paddle/platform/cudnn_helper.h

Lines changed: 7 additions & 6 deletions
@@ -1,11 +1,8 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
 http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -63,9 +60,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
   } \
 } while (false)
 
-enum class DataLayout {
+enum class DataLayout {  // Not used
   kNHWC,
   kNCHW,
+  kNCDHW,
   kNCHW_VECT_C,
 };

@@ -107,12 +105,15 @@ class CudnnDataType<double> {
   }
 };
 
-inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
+inline cudnnTensorFormat_t GetCudnnTensorFormat(
+    const DataLayout& order) {  // Not used
   switch (order) {
     case DataLayout::kNHWC:
       return CUDNN_TENSOR_NHWC;
     case DataLayout::kNCHW:
       return CUDNN_TENSOR_NCHW;
+    case DataLayout::kNCDHW:
+      return CUDNN_TENSOR_NCHW;  // TODO(chengduoZH): add CUDNN_TENSOR_NCDHW
     default:
       PADDLE_THROW("Unknown cudnn equivalent for order");
   }
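
This cuDNN version exposes no CUDNN_TENSOR_NCDHW value, so kNCDHW temporarily reuses CUDNN_TENSOR_NCHW; that is workable because 5-D descriptors end up being specified through explicit dims and strides (see ScopedTensorDescriptor in the next hunk) rather than through the format enum. A compilable sketch of the mapping, with stand-in types so it builds without cudnn.h:

#include <stdexcept>

enum class DataLayout { kNHWC, kNCHW, kNCDHW, kNCHW_VECT_C };

// Stand-in for cudnnTensorFormat_t so this sketch compiles without
// the cuDNN headers; the real function returns CUDNN_TENSOR_* values.
enum class TensorFormat { kNCHW, kNHWC };

TensorFormat GetTensorFormat(const DataLayout& order) {
  switch (order) {
    case DataLayout::kNHWC:
      return TensorFormat::kNHWC;
    case DataLayout::kNCHW:
      return TensorFormat::kNCHW;
    case DataLayout::kNCDHW:
      // No 5-D format enum exists yet, so reuse NCHW; the layout is
      // fully specified later via per-dimension strides.
      return TensorFormat::kNCHW;
    default:
      throw std::runtime_error("Unknown cudnn equivalent for order");
  }
}

int main() {
  return GetTensorFormat(DataLayout::kNCDHW) == TensorFormat::kNCHW ? 0 : 1;
}
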
@@ -139,7 +140,7 @@ class ScopedTensorDescriptor {
       strides[i] = dims[i + 1] * strides[i + 1];
     }
     // Update tensor descriptor dims setting if groups > 1
-    // FIXME(typhoonzero): Assume using NCHW order
+    // FIXME(typhoonzero): Assume using NCHW or NCDHW order
     std::vector<int> dims_with_group(dims.begin(), dims.end());  // copy
     if (groups > 1) {
       dims_with_group[1] = dims_with_group[1] / groups;
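
The updated FIXME reflects why the format fallback above is safe: ScopedTensorDescriptor computes fully packed row-major strides straight from the dims vector, which describes an NCDHW tensor just as well as an NCHW one, and only the channel dim is divided when groups > 1. A self-contained sketch of that computation (PackedStrides is an illustrative name, not Paddle's API):

#include <cassert>
#include <vector>

// Fully packed row-major strides: the innermost dim has stride 1 and
// each outer stride is the product of all inner dims, matching the
// loop in ScopedTensorDescriptor.
std::vector<int> PackedStrides(const std::vector<int>& dims) {
  std::vector<int> strides(dims.size());
  strides[dims.size() - 1] = 1;
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = dims[i + 1] * strides[i + 1];
  }
  return strides;
}

int main() {
  // NCDHW dims (N, C, D, H, W) = (2, 4, 8, 16, 16).
  std::vector<int> dims = {2, 4, 8, 16, 16};
  std::vector<int> strides = PackedStrides(dims);
  assert(strides[4] == 1);                // W: contiguous
  assert(strides[3] == 16);               // H: skips one row of W
  assert(strides[0] == 4 * 8 * 16 * 16);  // N: skips C*D*H*W elements

  // For grouped convolution only the channel dim shrinks; the strides
  // computed from the full dims are kept, as in the hunk above.
  int groups = 2;
  std::vector<int> dims_with_group(dims);  // copy
  if (groups > 1) dims_with_group[1] /= groups;
  assert(dims_with_group[1] == 2);
  return 0;
}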
