Skip to content

Commit 0e63746

Browse files
bingyanghuangluotao1
authored and committed
[cherry pick] Conv2d and Conv2d transpose MKL-DNN NHWC support (#21525)
1 parent 32a0eb5 commit 0e63746

File tree

12 files changed

+127
-43
lines changed

12 files changed

+127
-43
lines changed

paddle/fluid/framework/data_layout_transform.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
113113
PADDLE_THROW("wrong mkldnn type provided");
114114
}
115115
}
116-
#endif
117116

118117
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
119118
const OpKernelType& expected_kernel_type,
@@ -127,14 +126,11 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
127126
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
128127
"non-MKLDNN");
129128

130-
#ifdef PADDLE_WITH_MKLDNN
131129
innerTransDataLayoutFromMKLDNN(in_layout,
132130
paddle::platform::get_cur_paddle_data_layout(),
133131
in, out, place);
134-
#endif
135132
}
136133

137-
#ifdef PADDLE_WITH_MKLDNN
138134
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
139135
const Tensor& in, Tensor* out,
140136
platform::Place place) {

paddle/fluid/framework/data_layout_transform.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
6969
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
7070
const Tensor& in, Tensor* out,
7171
platform::Place place);
72-
#endif
7372

7473
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
7574
const OpKernelType& expected_kernel_type,
7675
const Tensor& in, Tensor* out);
76+
#endif
7777

7878
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
7979

paddle/fluid/framework/data_transform.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ void TransformData(const OpKernelType &expected_kernel_type,
4343

4444
// do layout transform
4545
if (NeedTransformLayout(lout, lin)) {
46+
#ifdef PADDLE_WITH_MKLDNN
4647
if (lin == DataLayout::kMKLDNN || lout == DataLayout::kMKLDNN) {
4748
PADDLE_ENFORCE(
4849
!(lin == DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN),
4950
"No layout transform needed between two MKLDNN OPKernels");
5051

5152
if (lin != DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN) {
52-
#ifdef PADDLE_WITH_MKLDNN
5353
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
5454
// Just set layout/format. No real transform occur
5555

@@ -67,7 +67,6 @@ void TransformData(const OpKernelType &expected_kernel_type,
6767
}
6868
out.set_layout(DataLayout::kMKLDNN);
6969
out.set_format(out_format);
70-
#endif
7170
} else {
7271
// Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel
7372
// Do transform via MKLDNN lib
@@ -78,6 +77,10 @@ void TransformData(const OpKernelType &expected_kernel_type,
7877
// Case3 - transform between Non-MKLDNN OPKernels
7978
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
8079
}
80+
#else
81+
// Case3 - transform between Non-MKLDNN OPKernels
82+
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
83+
#endif
8184
transformed = true;
8285
PassTensorData(&out, &in);
8386
}

paddle/fluid/operators/conv_op.cc

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
4848
int groups = ctx->Attrs().Get<int>("groups");
4949
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
5050
const std::string data_format = ctx->Attrs().Get<std::string>("data_format");
51-
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
51+
52+
// MKL-DNN Kernels are using NCHW order of dims description
53+
// so we ignore data_format consideration for MKL-DNN kernel
54+
const bool channel_last = (this->IsMKLDNNType() == false) &&
55+
(data_format == "NHWC" || data_format == "NDHWC");
5256

5357
PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
5458
"ShapeError: Conv input should be 4-D or 5-D tensor. But "
@@ -148,15 +152,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
148152
#ifdef PADDLE_WITH_MKLDNN
149153
if (library == framework::LibraryType::kPlain &&
150154
platform::CanMKLDNNBeUsed(ctx)) {
151-
// TODO(jczaja): Add support for NHWC
152-
const std::string data_format = ctx.Attr<std::string>("data_format");
153-
PADDLE_ENFORCE_NE(data_format, "NHWC",
154-
platform::errors::Unimplemented(
155-
"Conv MKLDNN does not support NHWC data format yet"));
156-
PADDLE_ENFORCE_NE(
157-
data_format, "NDHWC",
158-
platform::errors::Unimplemented(
159-
"Conv MKLDNN does not support NDHWC data format yet"));
160155
library = framework::LibraryType::kMKLDNN;
161156
layout = framework::DataLayout::kMKLDNN;
162157
customized_type_value =
@@ -194,6 +189,32 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
194189
return type;
195190
}
196191

192+
framework::OpKernelType ConvOp::GetKernelTypeForVar(
193+
const std::string& var_name, const Tensor& tensor,
194+
const framework::OpKernelType& expected_kernel_type) const {
195+
#ifdef PADDLE_WITH_MKLDNN
196+
// Only input require reshaping, weights and
197+
// bias are having shape in NCHW order
198+
if ((var_name == "Input") &&
199+
(expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
200+
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
201+
auto attrs = Attrs();
202+
auto ar = paddle::framework::AttrReader(attrs);
203+
const std::string data_format = ar.Get<std::string>("data_format");
204+
auto dl = framework::StringToDataLayout(data_format);
205+
// Some models may have intentionally set "AnyLayout" for pool
206+
// op. Treat this as NCHW (default data_format value)
207+
if (dl != framework::DataLayout::kAnyLayout) {
208+
return framework::OpKernelType(
209+
expected_kernel_type.data_type_, tensor.place(),
210+
framework::StringToDataLayout(data_format));
211+
}
212+
}
213+
#endif
214+
return framework::OpKernelType(expected_kernel_type.data_type_,
215+
tensor.place(), tensor.layout());
216+
}
217+
197218
void Conv2DOpMaker::Make() {
198219
AddAttr<bool>("is_test",
199220
"(bool, default false) Set to true for inference only, false "

paddle/fluid/operators/conv_op.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,10 @@ class ConvOp : public framework::OperatorWithKernel {
225225
protected:
226226
framework::OpKernelType GetExpectedKernelType(
227227
const framework::ExecutionContext& ctx) const override;
228+
229+
framework::OpKernelType GetKernelTypeForVar(
230+
const std::string& var_name, const Tensor& tensor,
231+
const framework::OpKernelType& expected_kernel_type) const override;
228232
};
229233

230234
class ConvOpGrad : public framework::OperatorWithKernel {

paddle/fluid/operators/conv_transpose_op.cc

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,11 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
4646
int groups = ctx->Attrs().Get<int>("groups");
4747
std::string padding_algorithm =
4848
ctx->Attrs().Get<std::string>("padding_algorithm");
49-
const DataLayout data_layout = framework::StringToDataLayout(
50-
ctx->Attrs().Get<std::string>("data_format"));
49+
const std::string data_layout_str =
50+
ctx->Attrs().Get<std::string>("data_format");
51+
const DataLayout data_layout =
52+
this->IsMKLDNNType() ? DataLayout::kNCHW
53+
: framework::StringToDataLayout(data_layout_str);
5154

5255
PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
5356
"ConvTransposeOp intput should be 4-D or 5-D tensor.");
@@ -127,11 +130,6 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
127130
#ifdef PADDLE_WITH_MKLDNN
128131
if (library_ == framework::LibraryType::kPlain &&
129132
platform::CanMKLDNNBeUsed(ctx)) {
130-
// TODO(jczaja): Add support for NHWC
131-
const std::string data_format = ctx.Attr<std::string>("data_format");
132-
PADDLE_ENFORCE_NE(
133-
data_format, "NHWC",
134-
"Conv Transpose MKLDNN does not support NHWC data format yet");
135133
library_ = framework::LibraryType::kMKLDNN;
136134
layout_ = framework::DataLayout::kMKLDNN;
137135
}
@@ -142,6 +140,32 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
142140
layout_, library_);
143141
}
144142

143+
framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
144+
const std::string& var_name, const Tensor& tensor,
145+
const framework::OpKernelType& expected_kernel_type) const {
146+
#ifdef PADDLE_WITH_MKLDNN
147+
// Only input require reshaping, weights and
148+
// bias are having shape in NCHW order
149+
if ((var_name == "Input") &&
150+
(expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
151+
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
152+
auto attrs = Attrs();
153+
auto ar = paddle::framework::AttrReader(attrs);
154+
const std::string data_format = ar.Get<std::string>("data_format");
155+
auto dl = framework::StringToDataLayout(data_format);
156+
// Some models may have intentionally set "AnyLayout" for pool
157+
// op. Treat this as NCHW (default data_format value)
158+
if (dl != framework::DataLayout::kAnyLayout) {
159+
return framework::OpKernelType(
160+
expected_kernel_type.data_type_, tensor.place(),
161+
framework::StringToDataLayout(data_format));
162+
}
163+
}
164+
#endif
165+
return framework::OpKernelType(expected_kernel_type.data_type_,
166+
tensor.place(), tensor.layout());
167+
}
168+
145169
void Conv2DTransposeOpMaker::Make() {
146170
AddAttr<bool>("is_test",
147171
"(bool, default false) Set to true for inference only, false "

paddle/fluid/operators/conv_transpose_op.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ class ConvTransposeOp : public framework::OperatorWithKernel {
9898
protected:
9999
framework::OpKernelType GetExpectedKernelType(
100100
const framework::ExecutionContext& ctx) const override;
101+
102+
framework::OpKernelType GetKernelTypeForVar(
103+
const std::string& var_name, const Tensor& tensor,
104+
const framework::OpKernelType& expected_kernel_type) const override;
101105
};
102106

103107
class ConvTransposeOpGrad : public framework::OperatorWithKernel {

paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
220220
* ('any') which lets a primitive (convolution in this case) choose
221221
* the memory format preferred for best performance
222222
*/
223+
// TODO(jczaja): This is workaround to make grad op UT's numerical
224+
// gradient computation proper as this op is called directly without
225+
// fetch op following it, so numerical grad is computed (in python)
226+
// using block formats which will give wrong results
223227
std::string data_format = ctx.Attr<std::string>("data_format");
224228
auto chosen_memory_format =
225-
platform::data_format_to_memory_format(data_format);
229+
is_test ? MKLDNNMemoryFormat::any
230+
: platform::data_format_to_memory_format(data_format);
226231

227232
weights_format = MKLDNNMemoryFormat::any;
228233
// Check the format for user's special output
@@ -519,9 +524,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
519524
* ('any') which lets a primitive (convolution in this case) choose
520525
* the memory format preferred for best performance
521526
*/
522-
std::string data_format = ctx.Attr<std::string>("data_format");
523-
auto chosen_memory_format =
524-
platform::data_format_to_memory_format(data_format);
527+
auto chosen_memory_format = MKLDNNMemoryFormat::any;
525528

526529
std::vector<int> bias_tz;
527530

@@ -772,18 +775,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
772775
* ('any') which lets a primitive (conv backward in this case) choose
773776
* the memory format preferred for best performance
774777
*/
775-
std::string data_format = ctx.Attr<std::string>("data_format");
776-
auto chosen_memory_format =
777-
platform::data_format_to_memory_format(data_format);
778-
778+
auto chosen_memory_format = MKLDNNMemoryFormat::any;
779779
weights_format = MKLDNNMemoryFormat::any;
780-
// Check the format for user's special output
781-
if (chosen_memory_format != MKLDNNMemoryFormat::any) {
782-
if (is_conv3d) {
783-
chosen_memory_format =
784-
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
785-
}
786-
}
787780

788781
auto src_md = platform::MKLDNNMemDesc(
789782
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);

paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
155155
* ('any') which lets a primitive (convolution in this case) choose
156156
* the memory format preferred for best performance
157157
*/
158-
std::string data_format = ctx.Attr<std::string>("data_format");
159-
auto chosen_memory_format =
160-
platform::data_format_to_memory_format(data_format);
158+
auto chosen_memory_format = MKLDNNMemoryFormat::any;
161159
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
162160
float fuse_alpha = ctx.Attr<float>("fuse_alpha");
163161
float fuse_beta = ctx.Attr<float>("fuse_beta");

python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def setUp(self):
3535
self.exhaustive_search = False
3636
self.use_cuda = False
3737
self.use_mkldnn = False
38-
self.data_format = "AnyLayout"
38+
self.data_format = "NCHW"
3939
self.weighttype = np.float32
4040
self.use_mkldnn = True
4141
self.init_group()

0 commit comments

Comments
 (0)