Commit 416b341

zhangting2020 authored and Aurelius84 committed
[cherry-pick] fix the bug of conv_transpose: compatible with AnyLayout setting, test=release/1.6 (#20897) (#20918)
1 parent 1948210 commit 416b341

File tree

5 files changed: +47 −45 lines


paddle/fluid/operators/conv_transpose_cudnn_op.cu (1 addition, 1 deletion)

@@ -316,7 +316,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
     int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
     const std::string data_layout_str = ctx.Attr<std::string>("data_format");
     const paddle::operators::DataLayout data_layout =
-        (data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC);
+        (data_layout_str != "NHWC" ? DataLayout::kNCHW : DataLayout::kNHWC);
 
     // if channel_last, transpose to channel_first
     Tensor input_transpose;
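
Why this one-line flip matters: the data_format attribute is not limited to "NCHW" and "NHWC"; under the AnyLayout setting it can carry another value, and the old equality test then fell into the channel-last (kNHWC) branch. Testing against "NHWC" instead makes channel-first the default for every other value. A minimal standalone sketch of the two behaviors (hypothetical resolve_layout_* helpers, not the operator's code):

#include <string>

enum class DataLayout { kNCHW, kNHWC };

// Old behavior: anything that is not exactly "NCHW" is treated as
// channel-last, so a data_format of "AnyLayout" picks the wrong branch.
DataLayout resolve_layout_old(const std::string& s) {
  return s == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC;
}

// Fixed behavior: channel-last only when explicitly requested; "NCHW",
// "AnyLayout", or anything else falls back to channel-first.
DataLayout resolve_layout_new(const std::string& s) {
  return s != "NHWC" ? DataLayout::kNCHW : DataLayout::kNHWC;
}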

paddle/fluid/operators/conv_transpose_op.h (3 additions, 1 deletion)

@@ -328,7 +328,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
         col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
                 data_layout);
       }
-      output_batch_vec.push_back(out_slice);
+      if (data_layout == framework::DataLayout::kNHWC) {
+        output_batch_vec.push_back(out_slice);
+      }
     }
     if (data_layout == framework::DataLayout::kNHWC) {
       concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
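
The NHWC code path assembles the output by buffering per-slice results in output_batch_vec and concatenating them afterwards (the concat_functor call above, which already runs only for kNHWC). Guarding the push_back the same way keeps the two sides consistent: in the channel-first path the slice is presumably written in place, so buffering it served no purpose. A simplified sketch of this collect-then-concat shape (std::vector stand-ins for Paddle tensors, assumed semantics):

#include <utility>
#include <vector>

enum class DataLayout { kNCHW, kNHWC };

void collect_slice(DataLayout layout, std::vector<float> out_slice,
                   std::vector<std::vector<float>>* output_batch_vec) {
  // Only the channel-last path concatenates buffered slices later, so
  // only that path needs to keep them around.
  if (layout == DataLayout::kNHWC) {
    output_batch_vec->push_back(std::move(out_slice));
  }
}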

paddle/fluid/operators/math/depthwise_conv.cu (29 additions, 29 deletions)

@@ -60,7 +60,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
     const int w_in_end = w_in_start + filter_width * dilate_width;
 
     int in_offset;
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       in_offset =
           ((batch * input_channels + c_in) * input_height) * input_width;
     } else {
@@ -78,7 +78,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
         if (h_in >= h_start && h_in < h_end && w_in >= w_start &&
             w_in < w_end) {
           int offset;
-          if (data_layout == DataLayout::kNCHW) {
+          if (data_layout != DataLayout::kNHWC) {
             offset = in_offset + h_in * input_width + w_in;
           } else {
             offset = in_offset +
@@ -94,7 +94,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
       }
     }
     int index;
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       index = ((batch * gridDim.x + c_out) * output_height + h_out) *
                   output_width +
               w_out;
@@ -131,7 +131,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
     const int w_in_end = w_in_start + c_filter * dilate_width;
 
     int in_offset;
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       in_offset =
           ((batch * input_channels + c_in) * input_height) * input_width;
     } else {
@@ -150,7 +150,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
         if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
             w_in < input_width) {
           int offset;
-          if (data_layout == DataLayout::kNCHW) {
+          if (data_layout != DataLayout::kNHWC) {
            offset = in_offset + h_in * input_width + w_in;
          } else {
            offset = in_offset +
@@ -166,7 +166,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
       }
     }
     int index;
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       index = ((batch * gridDim.x + c_out) * output_height + h_out) *
                   output_width +
               w_out;
@@ -252,7 +252,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
 
     T value = 0;
     int index;
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       index =
           ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
           w_in;
@@ -283,7 +283,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
             s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
             s_w_out < output_width) {
           int output_grad_offset;
-          if (data_layout == DataLayout::kNCHW) {
+          if (data_layout != DataLayout::kNHWC) {
             output_grad_offset =
                 ((batch * output_channels + c_out) * output_height +
                  s_h_out) *
@@ -335,7 +335,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
 
     T value = 0;
     int index;
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       index =
           ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
           w_in;
@@ -363,7 +363,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
             s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
             s_w_out < output_width) {
           int output_grad_offset;
-          if (data_layout == DataLayout::kNCHW) {
+          if (data_layout != DataLayout::kNHWC) {
             output_grad_offset =
                 ((batch * output_channels + c_out) * output_height +
                  s_h_out) *
@@ -449,7 +449,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad(
 #define gaid_nhwc(N, H, W, C) \
   ((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C))
     int input_id;
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       input_id = ((bid * (gridDim.z / filter_multiplier) +
                    kernel_id / filter_multiplier) *
                       input_height +
@@ -528,19 +528,19 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
                       const DataLayout data_layout = DataLayout::kNCHW) {
     const int batch_size = input.dims()[0];
     const int input_channels =
-        (data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
     const int input_height =
-        (data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
     const int input_width =
-        (data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
     const int output_channels =
-        (data_layout == DataLayout::kNCHW ? output->dims()[1]
+        (data_layout != DataLayout::kNHWC ? output->dims()[1]
                                           : output->dims()[3]);
     const int output_height =
-        (data_layout == DataLayout::kNCHW ? output->dims()[2]
+        (data_layout != DataLayout::kNHWC ? output->dims()[2]
                                           : output->dims()[1]);
     const int output_width =
-        (data_layout == DataLayout::kNCHW ? output->dims()[3]
+        (data_layout != DataLayout::kNHWC ? output->dims()[3]
                                           : output->dims()[2]);
     const int ksize_height = filter.dims()[2];
     const int ksize_width = filter.dims()[3];
@@ -614,19 +614,19 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
                       const DataLayout data_layout = DataLayout::kNCHW) {
     const int batch_size = input.dims()[0];
     const int input_channels =
-        (data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
     const int input_height =
-        (data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
     const int input_width =
-        (data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
     const int output_channels =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[1]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                           : output_grad.dims()[3]);
     const int output_height =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[2]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                           : output_grad.dims()[1]);
     const int output_width =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[3]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                           : output_grad.dims()[2]);
     const int ksize_height = filter.dims()[2];
     const int ksize_width = filter.dims()[3];
@@ -702,19 +702,19 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
                       const DataLayout data_layout = DataLayout::kNCHW) {
     const int batch_size = input.dims()[0];
     const int input_channels =
-        (data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
     const int input_height =
-        (data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
     const int input_width =
-        (data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
     const int output_channels =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[1]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                           : output_grad.dims()[3]);
     const int output_height =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[2]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                           : output_grad.dims()[1]);
     const int output_width =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[3]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                           : output_grad.dims()[2]);
     const int ksize_height = filter_grad->dims()[2];
     const int ksize_width = filter_grad->dims()[3];
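
All of the changes in this file are the same mechanical flip: every layout branch now tests data_layout != DataLayout::kNHWC instead of == DataLayout::kNCHW, so kAnyLayout (or any other non-NHWC value) takes the channel-first indexing path rather than silently falling into the channel-last one. The two formulas the branches choose between are the standard contiguous-tensor offsets; a standalone sketch (hypothetical helper, not a kernel from this file):

enum class DataLayout { kNCHW, kNHWC };

// Linear offset of element (n, c, h, w) in a contiguous 4-D tensor with
// C channels, height H, width W. The fix only changes which branch a
// non-NHWC layout lands in; the arithmetic itself is unchanged.
int element_offset(DataLayout layout, int n, int c, int h, int w,
                   int C, int H, int W) {
  if (layout != DataLayout::kNHWC) {
    return ((n * C + c) * H + h) * W + w;  // channel-first (NCHW)
  }
  return ((n * H + h) * W + w) * C + c;  // channel-last (NHWC)
}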

paddle/fluid/operators/math/im2col.cc (1 addition, 1 deletion)

@@ -115,7 +115,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
           if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
               (im_col_idx) >= 0 && (im_col_idx) < im_width) {
             int im_offset;
-            if (data_layout == DataLayout::kNCHW) {
+            if (data_layout != DataLayout::kNHWC) {
               im_offset =
                   (c_im * im_height + im_row_idx) * im_width + im_col_idx;
             } else {
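
The same default-layout flip on the CPU side. Here the chosen offset decides where a column element is accumulated back into the image; a minimal standalone sketch of that accumulation (assumed names mirroring the functor's variables, with the channel-last branch, elided in the hunk above, assumed to follow the same pattern as the CUDA kernels elsewhere in this commit):

enum class DataLayout { kNCHW, kNHWC };

// col2im sums overlapping patch contributions back into the image; the
// write offset depends on the image layout.
void col2im_accumulate(DataLayout data_layout, float* im_data, float col_value,
                       int c_im, int im_row_idx, int im_col_idx,
                       int im_channels, int im_height, int im_width) {
  int im_offset;
  if (data_layout != DataLayout::kNHWC) {
    im_offset = (c_im * im_height + im_row_idx) * im_width + im_col_idx;
  } else {
    im_offset = (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
  }
  im_data[im_offset] += col_value;
}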

paddle/fluid/operators/math/im2col.cu (13 additions, 13 deletions)

@@ -33,14 +33,14 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
   const int index =
       (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
   if (index < num_outs) {
-    int w_out = (data_layout == DataLayout::kNCHW
+    int w_out = (data_layout != DataLayout::kNHWC
                      ? index % col_width
                      : (index / input_channels) % col_width);
-    int h_out = (data_layout == DataLayout::kNCHW
+    int h_out = (data_layout != DataLayout::kNHWC
                      ? (index / col_width) % col_height
                      : (index / input_channels / col_width) % col_height);
     int channel_in =
-        (data_layout == DataLayout::kNCHW ? index / col_width / col_height
+        (data_layout != DataLayout::kNHWC ? index / col_width / col_height
                                           : index % input_channels);
     int channel_out = channel_in * filter_height * filter_width;
     int h_in = h_out * stride_height - padding_height;
@@ -52,7 +52,7 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
         int rIdx = h_in + i * dilation_h;
         int cIdx = w_in + j * dilation_w;
         int im_idx;
-        if (data_layout == DataLayout::kNCHW) {
+        if (data_layout != DataLayout::kNHWC) {
           im_idx = (channel_in * im_height + rIdx) * im_width + cIdx;
         } else {
           im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in;
@@ -86,11 +86,11 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                       "The dimension of col should be 5.");
 
     int im_channels =
-        (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
+        (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
     int im_height =
-        (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
+        (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
     int im_width =
-        (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
+        (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
     int filter_height = col->dims()[1];
     int filter_width = col->dims()[2];
     int col_height = col->dims()[3];
@@ -127,14 +127,14 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
 
   if (index < n) {
     T val = 0;
-    int w = (data_layout == DataLayout::kNCHW
+    int w = (data_layout != DataLayout::kNHWC
                  ? index % im_width + padding_width
                  : (index / input_channels) % im_width + padding_width);
-    int h = (data_layout == DataLayout::kNCHW
+    int h = (data_layout != DataLayout::kNHWC
                  ? (index / im_width) % im_height + padding_height
                  : (index / input_channels / im_width) % im_height +
                        padding_height);
-    int c = (data_layout == DataLayout::kNCHW ? index / im_width / im_height
+    int c = (data_layout != DataLayout::kNHWC ? index / im_width / im_height
                                               : index % input_channels);
 
     // compute the start and end of the output
@@ -187,11 +187,11 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                       "The dimension of col should be 5.");
 
     int im_channels =
-        (data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]);
+        (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
     int im_height =
-        (data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]);
+        (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
     int im_width =
-        (data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]);
+        (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
     int filter_height = col.dims()[1];
     int filter_width = col.dims()[2];
     int col_height = col.dims()[3];
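
In these kernels each CUDA thread first decodes its flat index into (channel, row, column) coordinates, and the decode order is layout-dependent; after the fix, any layout other than kNHWC decodes channel-major, matching the kernels' channel-first default. A standalone sketch of the decode (hypothetical helper mirroring the im2col arithmetic above):

enum class DataLayout { kNCHW, kNHWC };

struct Coord { int c, h, w; };

// Decode a flat index into coordinates of a (channels x height x width)
// volume. Channel-first stores w fastest, then h, then c; channel-last
// stores c fastest.
Coord decode_index(DataLayout layout, int index, int channels, int height,
                   int width) {
  Coord out;
  if (layout != DataLayout::kNHWC) {
    out.w = index % width;
    out.h = (index / width) % height;
    out.c = index / width / height;
  } else {
    out.c = index % channels;
    out.w = (index / channels) % width;
    out.h = (index / channels / width) % height;
  }
  return out;
}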
