Skip to content

Commit 1f598df

Browse files
bingyanghuang authored and
luotao1 committed
cherry-pick MKL-DNN NHWC FWD support fix (#21593)
1 parent f83254d commit 1f598df

File tree

5 files changed

+77
-3
lines changed

5 files changed

+77
-3
lines changed

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
7979
x_dims, x_dims.size());
8080

8181
const int64_t C =
82-
(data_layout == DataLayout::kNCHW ? x_dims[1]
83-
: x_dims[x_dims.size() - 1]);
82+
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
83+
? x_dims[1]
84+
: x_dims[x_dims.size() - 1]);
8485

8586
auto scale_dim = ctx->GetInputDim("Scale");
8687
auto bias_dim = ctx->GetInputDim("Bias");
@@ -154,6 +155,32 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
154155
library);
155156
}
156157

158+
framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
159+
const std::string &var_name, const Tensor &tensor,
160+
const framework::OpKernelType &expected_kernel_type) const {
161+
#ifdef PADDLE_WITH_MKLDNN
162+
// Only the input requires reshaping; weights and
163+
// bias already have their shape in NCHW order
164+
if ((var_name == "X") &&
165+
(expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
166+
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
167+
auto attrs = Attrs();
168+
auto ar = paddle::framework::AttrReader(attrs);
169+
const std::string data_layout = ar.Get<std::string>("data_layout");
170+
auto dl = framework::StringToDataLayout(data_layout);
171+
// Some models may have intentionally set "AnyLayout" for the batch_norm
172+
// op. Treat this as NCHW (default data_layout value)
173+
if (dl != framework::DataLayout::kAnyLayout) {
174+
return framework::OpKernelType(
175+
expected_kernel_type.data_type_, tensor.place(),
176+
framework::StringToDataLayout(data_layout));
177+
}
178+
}
179+
#endif
180+
return framework::OpKernelType(expected_kernel_type.data_type_,
181+
tensor.place(), tensor.layout());
182+
}
183+
157184
void BatchNormOpMaker::Make() {
158185
AddAttr<bool>("is_test",
159186
"(bool, default false) Set to true for inference only, false "
@@ -446,6 +473,12 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
446473
#ifdef PADDLE_WITH_MKLDNN
447474
if (library == framework::LibraryType::kPlain &&
448475
platform::CanMKLDNNBeUsed(ctx)) {
476+
// TODO(jczaja): Add support for NHWC
477+
const std::string data_layout = ctx.Attr<std::string>("data_layout");
478+
PADDLE_ENFORCE_NE(
479+
data_layout, "NHWC",
480+
platform::errors::Unimplemented(
481+
"Batch Norm MKLDNN grad does not support NHWC data format yet"));
449482
library = framework::LibraryType::kMKLDNN;
450483
layout = framework::DataLayout::kMKLDNN;
451484
}

paddle/fluid/operators/batch_norm_op.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ class BatchNormOp : public framework::OperatorWithKernel {
4747
protected:
4848
framework::OpKernelType GetExpectedKernelType(
4949
const framework::ExecutionContext &ctx) const override;
50+
51+
framework::OpKernelType GetKernelTypeForVar(
52+
const std::string &var_name, const Tensor &tensor,
53+
const framework::OpKernelType &expected_kernel_type) const override;
5054
};
5155

5256
class BatchNormGradOp : public framework::OperatorWithKernel {

paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -775,8 +775,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
775775
* ('any') which lets a primitive (conv backward in this case) choose
776776
* the memory format preferred for best performance
777777
*/
778-
auto chosen_memory_format = MKLDNNMemoryFormat::any;
778+
779+
// TODO(jczaja): Once GRAD NHWC is working, format 'any'
780+
// should be used exclusively. But as long as the forward pass enforces
781+
// NCHW for training, we need to have NCHW here as well
782+
// to avoid performance degradation in relu_grad and pool2d_grad
783+
std::string data_format = ctx.Attr<std::string>("data_format");
784+
auto chosen_memory_format =
785+
platform::data_format_to_memory_format(data_format);
786+
779787
weights_format = MKLDNNMemoryFormat::any;
788+
// Check the format for user's special output
789+
if (chosen_memory_format != MKLDNNMemoryFormat::any) {
790+
if (is_conv3d) {
791+
chosen_memory_format =
792+
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
793+
}
794+
}
780795

781796
auto src_md = platform::MKLDNNMemDesc(
782797
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);

python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ def test_check_output(self):
8484
self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
8585

8686

87+
class TestMKLDNNBatchNormOpInference_NHWC(TestMKLDNNBatchNormOpInference):
88+
def test_check_output(self):
89+
place = core.CPUPlace()
90+
data_format = "NHWC"
91+
self.check_with_place(place, data_format, self.dtype, [2, 4, 5, 3])
92+
93+
8794
class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference):
8895
def init_kernel_type(self):
8996
self.use_mkldnn = True

python/paddle/fluid/tests/unittests/test_batch_norm_op.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,21 @@ def check_with_place(self, place, data_layout, dtype, shape):
259259

260260
batch_norm_op.run(scope, place)
261261

262+
# When the op is called without an Executor, an
263+
# MKL-DNN Tensor is returned. For the NHWC data layout its
264+
# dims will be in NCHW order, as that is the MKL-DNN way
265+
# of describing memory. So we need to convert the NCHW
266+
# dims into NHWC.
267+
if data_layout == "NHWC" and self.use_mkldnn == True:
268+
# Create executor to have MKL-DNN cache
269+
# cleared after NHWC unit test
270+
place = core.CPUPlace()
271+
exe = fluid.Executor(place)
272+
dims = y_tensor.shape()
273+
c = dims.pop(1)
274+
dims.append(c)
275+
y_tensor._set_dims(dims)
276+
262277
# check inference result
263278
self.__assert_close(
264279
y_tensor,

0 commit comments

Comments
 (0)