Skip to content

Commit e26f51c

Browse files
Tomasz Patejko authored and luotao1 committed
MKLDNN elementwise_add with default broadcast operations (#11544)
* elementwise_add with bcast: Brian's implementation by Brian added, with default bcasts * elementwise_add with bcast: GetExpectedKernelType added to elementwise_op * elementwise_add with bcast: use_mkldnn attribute added * elementwise_add with bcast: changes after review and some formatting * elementwise_add with bcast: changes after style check * elementwise_add with bcast: changes after style check cont. * elementwise_add with bcast: MKLDNN unittests added * elementwise_add with bcast: original unittests with use_mkldnn flag * elementwise_add with bcast: handling of MKLDNN format corrected * elementwise_add with bcast: setting MKLDNN format turned into lambda * elementwise_add with bcast: MKDNN format setting turned into separate function * elementwise_add with bcast: condition for choosing MKLDNN simplified * elementwise_add with bcast: fix for MKLDNN format set incorrectly in bcasts * elementwise_add with bcast: changes in unittests for broadcasts * elementwise_add with bcast: fixes in unittests regarding dimensions * elementwise_add with bcast: bring back correct format setting in mklml grad path * elementwise_add with bcast: fixed compilation error
1 parent 67ab324 commit e26f51c

File tree

7 files changed

+376
-6
lines changed

7 files changed

+376
-6
lines changed

paddle/fluid/framework/data_layout_transform.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
147147
"Input tensor type is not supported: ", in.type().name());
148148
memory::data_type out_type = in_type;
149149

150-
memory::format in_format =
151-
in_tz.size() == 2 ? memory::format::nc : in.format();
152-
memory::format out_format =
153-
out_tz.size() == 2 ? memory::format::nc : ToMKLDNNFormat(out_layout);
150+
auto in_format = MKLDNNFormatForSize(in_tz.size(), in.format());
151+
auto out_format =
152+
MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
154153

155154
void* in_data = GetDataFromTensor(in, in_type);
156155

paddle/fluid/framework/data_layout_transform.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
6161
if (iter != dict.end()) return iter->second;
6262
return MKLDNNDataType::data_undef;
6363
}
64+
65+
// Maps a tensor rank to the canonical MKLDNN memory format: rank-1 tensors
// use format::x, rank-2 tensors use format::nc, and any higher rank keeps the
// caller-supplied default format unchanged.
inline MKLDNNFormat MKLDNNFormatForSize(size_t dims_size,
                                        MKLDNNFormat default_format) {
  if (dims_size == 1) {
    return mkldnn::memory::format::x;
  }
  if (dims_size == 2) {
    return mkldnn::memory::format::nc;
  }
  return default_format;
}
6471
#endif
6572

6673
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,

paddle/fluid/framework/data_transform.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,13 @@ void DataTransform(const OpKernelType& expected_kernel_type,
4747
#ifdef PADDLE_WITH_MKLDNN
4848
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
4949
// Just set layout/format. No real transform occur
50+
51+
auto out_format =
52+
MKLDNNFormatForSize(in.dims().size(), ToMKLDNNFormat(lin));
53+
5054
out.ShareDataWith(input_tensor);
5155
out.set_layout(DataLayout::kMKLDNN);
52-
out.set_format(ToMKLDNNFormat(lin));
56+
out.set_format(out_format);
5357
#endif
5458
} else {
5559
// Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/memory/memcpy.h"
16+
#include "paddle/fluid/operators/elementwise_add_op.h"
17+
#include "paddle/fluid/operators/elementwise_op_function.h"
18+
19+
#include "paddle/fluid/platform/mkldnn_helper.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
using framework::DataLayout;
25+
using framework::Tensor;
26+
using mkldnn::memory;
27+
using mkldnn::reorder;
28+
using mkldnn::primitive;
29+
using mkldnn::stream;
30+
using mkldnn::sum;
31+
32+
template <typename T>
33+
class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
34+
public:
35+
void Compute(const framework::ExecutionContext& ctx) const override {
36+
auto& dev_ctx =
37+
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
38+
const auto& mkldnn_engine = dev_ctx.GetEngine();
39+
40+
auto* x = ctx.Input<Tensor>("X");
41+
auto* y = ctx.Input<Tensor>("Y");
42+
auto* z = ctx.Output<Tensor>("Out");
43+
const T* x_data = x->data<T>();
44+
const T* y_data = y->data<T>();
45+
T* z_data = z->mutable_data<T>(ctx.GetPlace());
46+
47+
int axis = ctx.Attr<int>("axis");
48+
49+
auto x_dims = x->dims();
50+
auto y_dims = y->dims();
51+
auto z_dims = z->dims();
52+
53+
// Execute default elementwise_add operator when
54+
// broadcast operations need to performed.
55+
if (x_dims != y_dims) {
56+
auto sum_func = [](T a, T b) -> T { return a + b; };
57+
58+
TransformFunctor<decltype(sum_func), T,
59+
paddle::platform::CPUDeviceContext, T>
60+
functor(
61+
x, y, z,
62+
ctx.template device_context<paddle::platform::CPUDeviceContext>(),
63+
sum_func);
64+
65+
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
66+
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
67+
"Axis should be in range [0, x_dims)");
68+
69+
trim_trailing_singular_dims(&y_dims);
70+
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
71+
72+
int pre, n, post;
73+
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
74+
75+
if (post == 1) {
76+
functor.RunRowWise(n, pre);
77+
} else {
78+
functor.RunMidWise(n, pre, post);
79+
}
80+
z->set_layout(DataLayout::kMKLDNN);
81+
z->set_format(x->format());
82+
} else {
83+
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
84+
x->format() != memory::format::format_undef,
85+
"Wrong layout/format set for X tensor");
86+
PADDLE_ENFORCE(y->layout() == DataLayout::kMKLDNN &&
87+
y->format() != memory::format::format_undef,
88+
"Wrong layout/format set for X tensor");
89+
90+
std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
91+
std::vector<int> src_y_tz = framework::vectorize2int(y_dims);
92+
std::vector<int> dst_tz = framework::vectorize2int(z_dims);
93+
94+
std::vector<memory::primitive_desc> srcs_pd;
95+
std::vector<memory> srcs;
96+
std::vector<float> scales = {1.0f, 1.0f};
97+
98+
auto src_x_pd = memory::primitive_desc(
99+
{{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine);
100+
auto src_y_pd = memory::primitive_desc(
101+
{{src_y_tz}, memory::data_type::f32, y->format()}, mkldnn_engine);
102+
auto src_x_memory =
103+
memory(src_x_pd, paddle::platform::to_void_cast(x_data));
104+
auto src_y_memory =
105+
memory(src_y_pd, paddle::platform::to_void_cast(y_data));
106+
107+
srcs_pd.push_back(src_x_pd);
108+
srcs_pd.push_back(src_y_pd);
109+
srcs.push_back(src_x_memory);
110+
srcs.push_back(src_y_memory);
111+
112+
auto dst_md =
113+
memory::desc({dst_tz}, memory::data_type::f32, memory::format::any);
114+
115+
// create primitive descriptor for sum
116+
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd);
117+
118+
// create mkldnn memory for dst
119+
memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data);
120+
121+
std::vector<primitive::at> inputs;
122+
inputs.push_back(srcs[0]);
123+
inputs.push_back(srcs[1]);
124+
125+
// create sum primitive
126+
auto sum_prim = sum(sum_pd, inputs, dst_memory);
127+
128+
std::vector<primitive> pipeline;
129+
pipeline.push_back(sum_prim);
130+
stream(stream::kind::eager).submit(pipeline).wait();
131+
132+
z->set_layout(DataLayout::kMKLDNN);
133+
z->set_format(
134+
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
135+
}
136+
}
137+
};
138+
139+
template <typename T>
140+
class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
141+
public:
142+
void Compute(const framework::ExecutionContext& ctx) const override {
143+
using Tensor = framework::Tensor;
144+
145+
auto* x = ctx.Input<Tensor>("X");
146+
auto* y = ctx.Input<Tensor>("Y");
147+
auto* out = ctx.Input<Tensor>("Out");
148+
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
149+
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
150+
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
151+
int axis = ctx.Attr<int>("axis");
152+
153+
auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
154+
in->set_layout(DataLayout::kMKLDNN);
155+
in->set_format(out->format());
156+
};
157+
158+
if (x->dims() == y->dims()) {
159+
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
160+
if (dx) {
161+
blas.VCOPY(dout->numel(), dout->data<T>(),
162+
dx->mutable_data<T>(ctx.GetPlace()));
163+
set_mkldnn_format(dx, dout);
164+
}
165+
166+
if (dy) {
167+
blas.VCOPY(dout->numel(), dout->data<T>(),
168+
dy->mutable_data<T>(ctx.GetPlace()));
169+
set_mkldnn_format(dy, dout);
170+
}
171+
} else {
172+
// Execute default kernel when broadcast is needed
173+
ElemwiseGradCompute<paddle::platform::CPUDeviceContext, T,
174+
IdentityGrad<T>, IdentityGrad<T>>(
175+
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
176+
IdentityGrad<T>());
177+
}
178+
}
179+
};
180+
181+
} // namespace operators
182+
} // namespace paddle
183+
184+
// Shorthand alias used by the kernel-registration macros below.
namespace ops = paddle::operators;
185+
186+
// Register the float MKLDNN forward kernel for elementwise_add on CPU.
REGISTER_OP_KERNEL(elementwise_add, MKLDNN, ::paddle::platform::CPUPlace,
187+
ops::EltwiseAddMKLDNNKernel<float>)
188+
189+
// Register the float MKLDNN gradient kernel for elementwise_add_grad on CPU.
REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace,
190+
ops::EltwiseAddMKLDNNGradKernel<float>)

paddle/fluid/operators/elementwise_op.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@ limitations under the License. */
1414

1515
#pragma once
1616
#include <string>
17+
#include "paddle/fluid/framework/data_layout.h"
1718
#include "paddle/fluid/framework/op_registry.h"
1819
#include "paddle/fluid/framework/operator.h"
20+
#ifdef PADDLE_WITH_MKLDNN
21+
#include "paddle/fluid/platform/mkldnn_helper.h"
22+
#endif
1923

2024
namespace paddle {
2125
namespace operators {
@@ -40,6 +44,21 @@ class ElementwiseOp : public framework::OperatorWithKernel {
4044
ctx->SetOutputDim("Out", x_dim);
4145
ctx->ShareLoD("X", /*->*/ "Out");
4246
}
47+
48+
// Selects the kernel: data type follows the "X" input; the MKLDNN
// kernel/layout is preferred whenever the runtime allows it, otherwise
// the plain CPU/GPU kernel for the current place is used.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  const auto data_type = framework::ToDataType(ctx.Input<Tensor>("X")->type());

#ifdef PADDLE_WITH_MKLDNN
  if (platform::CanMKLDNNBeUsed(ctx)) {
    return framework::OpKernelType(data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(data_type, ctx.GetPlace());
}
4362
};
4463

4564
class ElementwiseOpInferVarType : public framework::VarTypeInference {
@@ -65,6 +84,8 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
6584
"for broadcasting Y onto X.")
6685
.SetDefault(-1)
6786
.EqualGreaterThan(-1);
87+
AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
88+
.SetDefault(false);
6889
AddComment(string::Sprintf(R"DOC(
6990
Limited Elementwise %s Operator
7091
@@ -138,6 +159,21 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
138159
ctx->SetOutputDim(y_grad_name, y_dims);
139160
}
140161
}
162+
163+
// Selects the gradient kernel: data type follows the "X" input; mirrors
// the forward op's choice so forward and backward agree — MKLDNN kernel
// and layout when available, plain kernel for the current place otherwise.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  const auto data_type = framework::ToDataType(ctx.Input<Tensor>("X")->type());

#ifdef PADDLE_WITH_MKLDNN
  if (platform::CanMKLDNNBeUsed(ctx)) {
    return framework::OpKernelType(data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(data_type, ctx.GetPlace());
}
141177
};
142178
} // namespace operators
143179
} // namespace paddle

0 commit comments

Comments
 (0)