Skip to content

Commit ccb508d

Browse files
bingyanghuang and luotao1
authored and committed
cherry-pick LRN and Pool2d (FWD) NHWC support (#21476)
1 parent 9ab738a commit ccb508d

File tree

16 files changed

+111
-39
lines changed

16 files changed

+111
-39
lines changed

paddle/fluid/framework/data_layout_transform.cc

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,17 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
127127
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
128128
"non-MKLDNN");
129129

130-
innerTransDataLayoutFromMKLDNN(in_layout, out_layout, in, out, place);
130+
#ifdef PADDLE_WITH_MKLDNN
131+
innerTransDataLayoutFromMKLDNN(in_layout,
132+
paddle::platform::get_cur_paddle_data_layout(),
133+
in, out, place);
134+
#endif
131135
}
132136

137+
#ifdef PADDLE_WITH_MKLDNN
133138
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
134139
const Tensor& in, Tensor* out,
135140
platform::Place place) {
136-
#ifdef PADDLE_WITH_MKLDNN
137141
PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::format_undef,
138142
"Input tensor should have specified memory format");
139143
PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::any,
@@ -181,11 +185,17 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
181185
} else {
182186
out->ShareDataWith(in);
183187
}
188+
// For expected NHWC data format we need to reshape the Output tensor
189+
// As MKL-DNN description was in NCHW and paddle is expecting NHWC
190+
if (out_layout == DataLayout::kNHWC) {
191+
std::rotate(out_tz.begin() + 1, out_tz.begin() + 2, out_tz.end());
192+
out->Resize(framework::make_ddim(out_tz));
193+
}
184194
out->set_layout(out_layout);
185195
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
186196
out->set_format(MKLDNNMemoryFormat::format_undef);
187-
#endif
188197
}
198+
#endif
189199

190200
} // namespace framework
191201
} // namespace paddle

paddle/fluid/framework/data_layout_transform.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,15 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
6666
return MKLDNNDataType::data_undef;
6767
}
6868

69+
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
70+
const Tensor& in, Tensor* out,
71+
platform::Place place);
6972
#endif
7073

7174
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
7275
const OpKernelType& expected_kernel_type,
7376
const Tensor& in, Tensor* out);
7477

75-
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
76-
const Tensor& in, Tensor* out,
77-
platform::Place place);
78-
7978
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
8079

8180
void TransDataLayout(const OpKernelType& kernel_type_for_var,

paddle/fluid/framework/data_transform.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919
#include "paddle/fluid/framework/data_type_transform.h"
2020

2121
#ifdef PADDLE_WITH_MKLDNN
22+
#include <algorithm>
2223
#include "paddle/fluid/platform/mkldnn_helper.h"
2324
#endif
2425

@@ -54,8 +55,16 @@ void TransformData(const OpKernelType &expected_kernel_type,
5455

5556
auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
5657
ToMKLDNNFormat(lin));
57-
5858
out.ShareDataWith(input_tensor);
59+
// For NHWC data we need reshape of tensors as MKL-DNN
60+
// is expecting NHWC dims description order
61+
if (lin == DataLayout::kNHWC) {
62+
auto nchw_dims = paddle::framework::vectorize<int>(out.dims());
63+
std::rotate(nchw_dims.begin() + 1, nchw_dims.end() - 1,
64+
nchw_dims.end());
65+
out.Resize(framework::make_ddim(nchw_dims));
66+
paddle::platform::set_cur_paddle_data_layout(lin);
67+
}
5968
out.set_layout(DataLayout::kMKLDNN);
6069
out.set_format(out_format);
6170
#endif

paddle/fluid/framework/executor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ Executor::~Executor() {
103103
platform::MKLDNNDeviceContext* dev_ctx =
104104
(platform::MKLDNNDeviceContext*)pool.Get(place_);
105105
dev_ctx->ResetBlobMap();
106+
platform::set_cur_paddle_data_layout(paddle::framework::DataLayout::kNCHW);
106107
}
107108
#endif
108109
}

paddle/fluid/framework/operator.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,11 @@ class OperatorWithKernel : public OperatorBase {
443443
return g_all_op_kernels;
444444
}
445445

446+
bool IsMKLDNNType() const {
447+
return ((this->kernel_type_) && (this->kernel_type_->data_layout_ ==
448+
framework::DataLayout::kMKLDNN));
449+
}
450+
446451
bool SupportGPU() const override {
447452
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
448453
return std::any_of(op_kernels.begin(), op_kernels.end(),

paddle/fluid/operators/controlflow/fetch_op.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,20 @@ class FetchOp : public framework::OperatorBase {
5656
// FIXME(yuyang18): Should we assume the fetch operator always generate
5757
// CPU outputs?
5858
if (src_item.IsInitialized() && src_item.numel() > 0) {
59+
#ifdef PADDLE_WITH_MKLDNN
5960
// Conversion from MKL-DNN to Paddle
6061
if (src_item.layout() == framework::DataLayout::kMKLDNN) {
6162
framework::Tensor out;
6263
framework::innerTransDataLayoutFromMKLDNN(
63-
src_item.layout(), framework::DataLayout::kNCHW, src_item, &out,
64-
platform::CPUPlace());
64+
src_item.layout(), paddle::platform::get_cur_paddle_data_layout(),
65+
src_item, &out, platform::CPUPlace());
6566
TensorCopySync(out, platform::CPUPlace(), &dst_item);
6667
} else {
6768
TensorCopySync(src_item, platform::CPUPlace(), &dst_item);
6869
}
70+
#else
71+
TensorCopySync(src_item, platform::CPUPlace(), &dst_item);
72+
#endif
6973
} else {
7074
// Not copy, if the src tensor is empty.
7175
dst_item.clear();

paddle/fluid/operators/lrn_op.cc

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,6 @@ class LRNOp : public framework::OperatorWithKernel {
193193
#ifdef PADDLE_WITH_MKLDNN
194194
if (library_ == framework::LibraryType::kPlain &&
195195
platform::CanMKLDNNBeUsed(ctx)) {
196-
// TODO(jczaja): Add support for NHWC
197-
const std::string data_format = ctx.Attr<std::string>("data_format");
198-
PADDLE_ENFORCE_NE(
199-
data_format, "NHWC",
200-
platform::errors::Unimplemented(
201-
"LRN MKLDNN does not support NHWC data format yet"));
202196
library_ = framework::LibraryType::kMKLDNN;
203197
layout_ = framework::DataLayout::kMKLDNN;
204198
}
@@ -207,6 +201,28 @@ class LRNOp : public framework::OperatorWithKernel {
207201
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
208202
layout_, library_);
209203
}
204+
205+
framework::OpKernelType GetKernelTypeForVar(
206+
const std::string& var_name, const Tensor& tensor,
207+
const framework::OpKernelType& expected_kernel_type) const override {
208+
#ifdef PADDLE_WITH_MKLDNN
209+
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
210+
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
211+
auto attrs = Attrs();
212+
auto ar = paddle::framework::AttrReader(attrs);
213+
const std::string data_format = ar.Get<std::string>("data_format");
214+
auto dl = framework::StringToDataLayout(data_format);
215+
// Some models may have intentionally set "AnyLayout" for LRN
216+
// op. Treat this as NCHW (default data_format value)
217+
if (dl != framework::DataLayout::kAnyLayout) {
218+
return framework::OpKernelType(expected_kernel_type.data_type_,
219+
tensor.place(), dl);
220+
}
221+
}
222+
#endif
223+
return framework::OpKernelType(expected_kernel_type.data_type_,
224+
tensor.place(), tensor.layout());
225+
}
210226
};
211227

212228
template <typename T>

paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
101101
pipeline.push_back(*reorder_p);
102102
stream(stream::kind::eager).submit(pipeline).wait();
103103

104+
output->set_layout(DataLayout::kMKLDNN);
104105
output->set_format(GetMKLDNNFormat(*dst_memory));
105106
}
106107
};

paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
6262
std::shared_ptr<mkldnn::lrn_forward> lrn_p;
6363
if (is_test == false) {
6464
workspace_memory = handler.AcquireWorkspaceMemory(mid);
65+
mid->set_layout(framework::DataLayout::kMKLDNN);
66+
mid->set_format(platform::GetMKLDNNFormat(*workspace_memory));
6567
lrn_p = handler.AcquireForwardPrimitive(*src_memory, *workspace_memory,
6668
*dst_memory);
6769
} else {

paddle/fluid/operators/pool_op.cc

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,10 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
7171
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
7272
"Strides size and pooling size should be the same.");
7373

74-
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
74+
// MKL-DNN Kernels are using NCHW order of dims description
75+
// so we ignore data_format consideration for MKL-DNN kernel
76+
const bool channel_last = (this->IsMKLDNNType() == false) &&
77+
(data_format == "NHWC" || data_format == "NDHWC");
7578

7679
// update paddings if "SAME" or global_pooling
7780
framework::DDim data_dims;
@@ -129,12 +132,6 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
129132
#ifdef PADDLE_WITH_MKLDNN
130133
if (library_ == framework::LibraryType::kPlain &&
131134
platform::CanMKLDNNBeUsed(ctx)) {
132-
// TODO(jczaja): Add support for NHWC
133-
const std::string data_format = ctx.Attr<std::string>("data_format");
134-
PADDLE_ENFORCE_NE(
135-
data_format, "NHWC",
136-
platform::errors::Unimplemented(
137-
"Pool MKLDNN grad does not support NHWC data format yet"));
138135
library_ = framework::LibraryType::kMKLDNN;
139136
layout_ = framework::DataLayout::kMKLDNN;
140137
}
@@ -145,6 +142,28 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
145142
layout_, library_);
146143
}
147144

145+
framework::OpKernelType PoolOp::GetKernelTypeForVar(
146+
const std::string& var_name, const Tensor& tensor,
147+
const framework::OpKernelType& expected_kernel_type) const {
148+
#ifdef PADDLE_WITH_MKLDNN
149+
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
150+
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
151+
auto attrs = Attrs();
152+
auto ar = paddle::framework::AttrReader(attrs);
153+
const std::string data_format = ar.Get<std::string>("data_format");
154+
auto dl = framework::StringToDataLayout(data_format);
155+
// Some models may have intentionally set "AnyLayout" for pool
156+
// op. Treat this as NCHW (default data_format value)
157+
if (dl != framework::DataLayout::kAnyLayout) {
158+
return framework::OpKernelType(expected_kernel_type.data_type_,
159+
tensor.place(), dl);
160+
}
161+
}
162+
#endif
163+
return framework::OpKernelType(expected_kernel_type.data_type_,
164+
tensor.place(), tensor.layout());
165+
}
166+
148167
void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
149168
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must not be null.");
150169
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,

0 commit comments

Comments
 (0)