Skip to content

Commit 64045c2

Browse files
authored
Merge pull request #11102 from mozga-intel/mozga-intel/Sum_mkldnn_layout
MKLDNN layout: Support for sum operator
2 parents 2074d36 + b88cda8 commit 64045c2

File tree

13 files changed

+411
-102
lines changed

13 files changed

+411
-102
lines changed

paddle/fluid/operators/parallel_do_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ class ParallelDoGradOp : public framework::OperatorBase {
295295

296296
auto sum_op = framework::OpRegistry::CreateOp(
297297
"sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}},
298-
framework::AttributeMap{});
298+
framework::AttributeMap{{"use_mkldnn", {false}}});
299299
VLOG(10) << sum_op->DebugStringEx(sub_scopes[0]);
300300
sum_op->Run(*sub_scopes[0], places[0]);
301301
WaitOnPlace(places[0]);

paddle/fluid/operators/recurrent_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,8 @@ class RecurrentGradOp : public RecurrentBase {
429429

430430
auto sum_op = framework::OpRegistry::CreateOp(
431431
"sum", {{"X", {pg_names[param_id], new_inside_name}}},
432-
{{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
432+
{{"Out", {pg_names[param_id]}}},
433+
framework::AttributeMap{{"use_mkldnn", {false}}});
433434
sum_op->Run(cur_scope, place);
434435

435436
cur_scope.Rename(new_inside_name, inside_grad_name);
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/*Licensed under the Apache License, Version 2.0(the "License");
16+
you may not use this file except in compliance with the License.
17+
You may obtain a copy of the License at
18+
19+
http://www.apache.org/licenses/LICENSE-2.0
20+
21+
Unless required by applicable law or agreed to in writing, software
22+
distributed under the License is distributed on an "AS IS" BASIS,
23+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24+
See the License for the specific language governing permissions and
25+
limitations under the License. */
26+
27+
#include "mkldnn.hpp"
28+
#include "paddle/fluid/framework/tensor.h"
29+
#include "paddle/fluid/operators/math/selected_rows_functor.h"
30+
#include "paddle/fluid/operators/sum_op.h"
31+
#include "paddle/fluid/platform/device_context.h"
32+
#include "paddle/fluid/platform/mkldnn_helper.h"
33+
34+
namespace paddle {
35+
namespace operators {
36+
37+
using paddle::framework::Tensor;
38+
using paddle::platform::MKLDNNDeviceContext;
39+
using paddle::platform::CPUDeviceContext;
40+
using framework::DataLayout;
41+
using mkldnn::memory;
42+
using mkldnn::primitive;
43+
using mkldnn::stream;
44+
using mkldnn::sum;
45+
using mkldnn::reorder;
46+
using platform::to_void_cast;
47+
48+
template <typename T>
49+
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
50+
public:
51+
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
52+
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
53+
"It must use CPUPlace.");
54+
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
55+
const auto& mkldnn_engine = dev_ctx.GetEngine();
56+
auto in_vars = ctx.MultiInputVar("X");
57+
58+
const int N = in_vars.size();
59+
auto out_var = ctx.OutputVar("Out");
60+
bool in_place = out_var == in_vars[0];
61+
62+
if (out_var->IsType<framework::LoDTensor>()) {
63+
LoDTensor* output = ctx.Output<LoDTensor>("Out");
64+
T* output_data = output->mutable_data<T>(ctx.GetPlace());
65+
66+
std::vector<int> dst_tz = framework::vectorize2int(output->dims());
67+
auto src_tz = dst_tz;
68+
memory::format output_format{memory::format::format_undef};
69+
std::vector<float> scales;
70+
std::vector<memory::primitive_desc> srcs_mpd;
71+
std::vector<mkldnn::memory> srcs_mem;
72+
73+
PADDLE_ENFORCE(in_vars[0]->IsType<LoDTensor>(),
74+
"Input[0] must be LoDTensors");
75+
auto& input0 = in_vars[0]->Get<LoDTensor>();
76+
PADDLE_ENFORCE(input0.layout() == DataLayout::kMKLDNN &&
77+
input0.format() != memory::format::format_undef,
78+
"Wrong layout/format for inputs[0]");
79+
80+
memory::format input_format = input0.format();
81+
82+
if (src_tz.size() == 1 && (input_format == memory::format::nchw ||
83+
input_format == memory::format::nhwc)) {
84+
input_format = memory::format::x;
85+
}
86+
if (src_tz.size() == 2 && (input_format == memory::format::nchw ||
87+
input_format == memory::format::nhwc)) {
88+
input_format = memory::format::nc;
89+
}
90+
91+
for (int i = in_place ? 1 : 0; i < N; i++) {
92+
PADDLE_ENFORCE(in_vars[i]->IsType<LoDTensor>(),
93+
"all inputs must be all LoDTensors");
94+
auto& input = in_vars[i]->Get<LoDTensor>();
95+
PADDLE_ENFORCE(input.layout() == DataLayout::kMKLDNN &&
96+
input.format() != memory::format::format_undef,
97+
"Wrong layout/format for inputs");
98+
99+
if (input.numel() == 0) {
100+
continue;
101+
}
102+
103+
const T* input_data = input.data<T>();
104+
105+
auto src_md =
106+
memory::desc(src_tz, memory::data_type::f32, input_format);
107+
auto src_mpd = memory::primitive_desc(src_md, mkldnn_engine);
108+
auto src_mem = memory(src_mpd, to_void_cast(input_data));
109+
srcs_mpd.push_back(src_mpd);
110+
srcs_mem.push_back(src_mem);
111+
scales.push_back(1.0);
112+
}
113+
114+
auto dst_md =
115+
memory::desc(dst_tz, memory::data_type::f32, memory::format::any);
116+
117+
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
118+
119+
std::shared_ptr<memory> dst_mem;
120+
if (in_place) {
121+
dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
122+
} else {
123+
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
124+
}
125+
std::vector<mkldnn::primitive::at> inputs;
126+
for (size_t i = 0; i < srcs_mem.size(); ++i) {
127+
inputs.push_back(srcs_mem[i]);
128+
}
129+
130+
auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem);
131+
output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd);
132+
133+
primitive reorder_prim;
134+
std::shared_ptr<memory> target_mem;
135+
if (in_place) {
136+
output_format = input_format;
137+
target_mem.reset(new memory(
138+
{{{src_tz}, memory::data_type::f32, output_format}, mkldnn_engine},
139+
output_data));
140+
reorder_prim = reorder(*dst_mem, *target_mem);
141+
}
142+
143+
std::vector<primitive> pipeline;
144+
pipeline.push_back(sum_prim);
145+
if (in_place) pipeline.push_back(reorder_prim);
146+
stream(stream::kind::eager).submit(pipeline).wait();
147+
148+
output->set_layout(DataLayout::kMKLDNN);
149+
output->set_format(output_format);
150+
} else if (out_var->IsType<framework::SelectedRows>()) {
151+
// TODO(@mozga-intel) Add MKLDNN SelectedRows support
152+
std::unique_ptr<framework::SelectedRows> in0;
153+
if (in_place) {
154+
// If is in_place, we store the input[0] to in0
155+
auto& in_sel0 = in_vars[0]->Get<SelectedRows>();
156+
auto& rows = in_sel0.rows();
157+
in0.reset(new framework::SelectedRows(rows, in_sel0.height()));
158+
in0->mutable_value()->ShareDataWith(in_sel0.value());
159+
}
160+
161+
auto get_selected_row = [&](size_t i) -> const SelectedRows& {
162+
if (i == 0 && in0) {
163+
return *in0.get();
164+
} else {
165+
return in_vars[i]->Get<SelectedRows>();
166+
}
167+
};
168+
auto* out = ctx.Output<SelectedRows>("Out");
169+
out->mutable_rows()->clear();
170+
auto* out_value = out->mutable_value();
171+
172+
// Runtime InferShape
173+
size_t first_dim = 0;
174+
for (int i = 0; i < N; i++) {
175+
auto& sel_row = get_selected_row(i);
176+
first_dim += sel_row.rows().size();
177+
}
178+
auto in_dim =
179+
framework::vectorize(get_selected_row(N - 1).value().dims());
180+
in_dim[0] = static_cast<int64_t>(first_dim);
181+
182+
out_value->Resize(framework::make_ddim(in_dim));
183+
184+
// if all the input sparse vars are empty, no need to
185+
// merge these vars.
186+
if (first_dim == 0UL) {
187+
return;
188+
}
189+
out_value->mutable_data<T>(ctx.GetPlace());
190+
math::SelectedRowsAddTo<CPUDeviceContext, T> functor;
191+
int64_t offset = 0;
192+
for (int i = 0; i < N; i++) {
193+
auto& sel_row = get_selected_row(i);
194+
if (sel_row.rows().size() == 0) {
195+
continue;
196+
}
197+
PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
198+
functor(ctx.template device_context<CPUDeviceContext>(), sel_row,
199+
offset, out);
200+
offset += sel_row.value().numel();
201+
}
202+
} else if (out_var->IsType<framework::LoDTensorArray>()) {
203+
// TODO(@mozga-intel) Add MKLDNN LoDTensorArray support
204+
auto& out_array = *out_var->GetMutable<framework::LoDTensorArray>();
205+
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
206+
PADDLE_ENFORCE(in_vars[i]->IsType<framework::LoDTensorArray>(),
207+
"Only support all inputs are TensorArray");
208+
auto& in_array = in_vars[i]->Get<framework::LoDTensorArray>();
209+
210+
for (size_t i = 0; i < in_array.size(); ++i) {
211+
if (in_array[i].numel() != 0) {
212+
if (i >= out_array.size()) {
213+
out_array.resize(i + 1);
214+
}
215+
if (out_array[i].numel() == 0) {
216+
framework::TensorCopy(in_array[i], in_array[i].place(),
217+
ctx.device_context(), &out_array[i]);
218+
out_array[i].set_lod(in_array[i].lod());
219+
} else {
220+
PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
221+
auto in = EigenVector<T>::Flatten(in_array[i]);
222+
auto result = EigenVector<T>::Flatten(out_array[i]);
223+
result.device(*ctx.template device_context<MKLDNNDeviceContext>()
224+
.eigen_device()) = result + in;
225+
}
226+
}
227+
}
228+
}
229+
} else {
230+
PADDLE_THROW("Unexpected branch, output variable type is %s",
231+
out_var->Type().name());
232+
}
233+
}
234+
};
235+
236+
} // namespace operators
237+
} // namespace paddle
238+
239+
// Registers SumMKLDNNOpKernel<float> as the MKLDNN kernel of the "sum"
// operator for CPUPlace.
REGISTER_OP_KERNEL(sum, MKLDNN, ::paddle::platform::CPUPlace,
240+
paddle::operators::SumMKLDNNOpKernel<float>);

paddle/fluid/operators/sum_op.cc

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ limitations under the License. */
1818
#include "paddle/fluid/framework/var_type_inference.h"
1919
#include "paddle/fluid/operators/detail/safe_ref.h"
2020

21+
#ifdef PADDLE_WITH_MKLDNN
22+
#include "paddle/fluid/platform/mkldnn_helper.h"
23+
#endif
24+
2125
namespace paddle {
2226
namespace operators {
2327
using framework::Tensor;
@@ -63,6 +67,18 @@ class SumOp : public framework::OperatorWithKernel {
6367
framework::OpKernelType GetExpectedKernelType(
6468
const framework::ExecutionContext& ctx) const override {
6569
auto x_vars = ctx.MultiInputVar("X");
70+
71+
framework::LibraryType library{framework::LibraryType::kPlain};
72+
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
73+
74+
#ifdef PADDLE_WITH_MKLDNN
75+
if (library == framework::LibraryType::kPlain &&
76+
platform::CanMKLDNNBeUsed(ctx)) {
77+
library = framework::LibraryType::kMKLDNN;
78+
layout = framework::DataLayout::kMKLDNN;
79+
}
80+
#endif
81+
6682
if (x_vars[0]->IsType<framework::LoDTensor>()) {
6783
int dtype = -1;
6884
for (auto& x_var : x_vars) {
@@ -80,26 +96,27 @@ class SumOp : public framework::OperatorWithKernel {
8096
"Sum operator should have at least one tensor");
8197

8298
return framework::OpKernelType(
83-
static_cast<framework::proto::VarType::Type>(dtype),
84-
ctx.device_context());
99+
static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(),
100+
layout, library);
85101
} else if (x_vars[0]->IsType<framework::SelectedRows>()) {
86102
for (auto& var : x_vars) {
87103
auto& value = var->Get<framework::SelectedRows>().value();
88104
if (value.IsInitialized()) {
89105
return framework::OpKernelType(framework::ToDataType(value.type()),
90-
ctx.device_context());
106+
ctx.device_context(), layout, library);
91107
}
92108
}
93109
// if input sparse vars are not initialized, use an default kernel type.
94110
return framework::OpKernelType(framework::proto::VarType::FP32,
95-
ctx.device_context());
111+
ctx.device_context(), layout, library);
96112
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
97113
for (auto& x_var : x_vars) {
98114
auto& array = x_var->Get<framework::LoDTensorArray>();
99115
for (auto& each : array) {
100116
if (each.numel() != 0) {
101117
return framework::OpKernelType(framework::ToDataType(each.type()),
102-
ctx.device_context());
118+
ctx.device_context(), layout,
119+
library);
103120
}
104121
}
105122
}
@@ -116,6 +133,9 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
116133
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
117134
.AsDuplicable();
118135
AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X");
136+
AddAttr<bool>("use_mkldnn",
137+
"(bool, default false) Only used in mkldnn kernel")
138+
.SetDefault(false);
119139
AddComment(R"DOC(
120140
Sum operator.
121141
@@ -132,7 +152,6 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
132152
framework::BlockDesc* block) const override {
133153
auto& inputs = op_desc.Input("X");
134154
auto var_type = framework::proto::VarType::SELECTED_ROWS;
135-
136155
for (auto& name : op_desc.Input("X")) {
137156
VLOG(10) << name << " "
138157
<< block->FindRecursiveOrCreateVar(name).GetType();
@@ -206,6 +225,7 @@ namespace ops = paddle::operators;
206225

207226
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
208227
ops::SumOpVarTypeInference);
228+
209229
REGISTER_OP_CPU_KERNEL(
210230
sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
211231
ops::SumKernel<paddle::platform::CPUDeviceContext, double>,

paddle/fluid/operators/while_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,11 @@ class WhileGradOp : public framework::OperatorBase {
203203
->set_lod(inside_tensor.lod());
204204
}
205205
}
206-
207206
auto new_inside_name = cur_scope.Rename(inside_grad_name);
208207
auto sum_op = framework::OpRegistry::CreateOp(
209208
"sum", {{"X", {pg_names[param_id], new_inside_name}}},
210-
{{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
209+
{{"Out", {pg_names[param_id]}}},
210+
framework::AttributeMap{{"use_mkldnn", {false}}});
211211
sum_op->Run(cur_scope, dev_place);
212212
cur_scope.Rename(new_inside_name, inside_grad_name);
213213
}

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,5 +99,11 @@ inline mkldnn::memory::format GetMKLDNNFormat(const mkldnn::memory memory) {
9999
memory.get_primitive_desc().desc().data.format);
100100
}
101101

102+
// Returns the memory format recorded in the destination primitive descriptor
// of the given sum primitive descriptor.
inline mkldnn::memory::format GetMKLDNNFormat(
    const mkldnn::sum::primitive_desc& memory) {
  const auto dst_desc = memory.dst_primitive_desc().desc();
  const auto raw_format = dst_desc.data.format;
  return static_cast<mkldnn::memory::format>(raw_format);
}
107+
102108
} // namespace platform
103109
} // namespace paddle

python/paddle/fluid/backward.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,9 @@ def _addup_repetitive_outputs_(op_descs):
132132
for idx, op_desc in enumerate(op_descs):
133133
for var_name in op_desc.input_arg_names():
134134
if len(renamed_vars[var_name]) > 1:
135-
pending_sum_ops.append(
136-
(_create_op_desc_("sum", {"X": renamed_vars[var_name]},
137-
{"Out": [var_name]}, {}), idx))
135+
pending_sum_ops.append((_create_op_desc_(
136+
"sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]},
137+
{"use_mkldnn": False}), idx))
138138
renamed_vars[var_name] = [var_name]
139139
for var_name in op_desc.output_arg_names():
140140
if var_name == core.empty_var_name(
@@ -161,8 +161,9 @@ def _addup_repetitive_outputs_(op_descs):
161161
renamed_vars[var_name].append(new_name)
162162
for var_name, inputs in renamed_vars.iteritems():
163163
if len(inputs) > 1:
164-
pending_sum_ops.append((_create_op_desc_(
165-
"sum", {"X": inputs}, {"Out": [var_name]}, {}), len(op_descs)))
164+
pending_sum_ops.append(
165+
(_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]},
166+
{"use_mkldnn": False}), len(op_descs)))
166167
# sum_op descs are sorted according to their insert position
167168
for p in reversed(pending_sum_ops):
168169
op_descs.insert(p[1], p[0])

0 commit comments

Comments
 (0)