Skip to content

Commit ea3538d

Browse files
Added fused operator
test=develop
1 parent 9a6e239 commit ea3538d

File tree

7 files changed

+416
-7
lines changed

7 files changed

+416
-7
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ cc_library(version SRCS version.cc)
136136
cc_test(version_test SRCS version_test.cc DEPS version)
137137

138138
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
139+
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto)
140+
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
141+
shape_inference data_transform lod_tensor profiler)
142+
139143

140144
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
141145
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
@@ -163,10 +167,10 @@ if(WITH_DISTRIBUTE)
163167
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
164168
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
165169
else()
166-
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
170+
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
167171
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
168172
endif()
169-
173+
170174
if (NOT WIN32)
171175
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
172176
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor

paddle/fluid/framework/executor.cc

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include "paddle/fluid/framework/feed_fetch_method.h"
1818
#include "paddle/fluid/framework/lod_rank_table.h"
1919
#include "paddle/fluid/framework/lod_tensor_array.h"
20+
#include "paddle/fluid/framework/ngraph_operator.h"
2021
#include "paddle/fluid/framework/op_registry.h"
2122
#include "paddle/fluid/framework/reader.h"
2223
#include "paddle/fluid/operators/detail/macros.h"
@@ -25,6 +26,7 @@ limitations under the License. */
2526

2627
DECLARE_bool(benchmark);
2728
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
29+
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
2830

2931
namespace paddle {
3032
namespace framework {
@@ -81,6 +83,24 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
8183
}
8284
}
8385

86+
// Rewrites ctx->ops_ in place: each maximal run of consecutive
// ngraph-supported ops (located by FusedOperator::FusedOpIntervals) is
// collapsed into a single FusedOperator. Logs a warning and does nothing
// when Paddle was built without ngraph support.
static void EnableFusedOp(ExecutorPrepareContext* ctx) {
#ifdef PADDLE_WITH_NGRAPH
  VLOG(3) << "use_ngraph=True";
  // Each interval is {first, one-past-last} iterators into ctx->ops_.
  auto intervals = FusedOperator::FusedOpIntervals(&ctx->ops_);
  for (auto& interval : intervals) {
    // FusedOperator's ctor moves the ops in [first, last) into itself,
    // leaving the source slots as null unique_ptrs.
    auto* fused_op = new FusedOperator(ctx->prog_, ctx->block_id_,
                                       interval.at(0), interval.at(1));
    // The first slot of the run now owns the fused op.
    *interval[0] = std::unique_ptr<OperatorBase>(fused_op);
  }
  // Erase the leftover moved-from slots. Iterating back to front keeps the
  // iterators stored for earlier intervals valid while later ones are erased.
  for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
    ctx->ops_.erase(it->at(0) + 1, it->at(1));
  }
#else
  LOG(WARNING)
      << "'NGRAPH' is not supported, Please re-compile with WITH_NGRAPH option";
#endif
}
103+
84104
Executor::Executor(const platform::Place& place) : place_(place) {}
85105

86106
void Executor::Close() {
@@ -338,6 +358,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
338358
for (auto& op_desc : block.AllOps()) {
339359
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
340360
}
361+
if (FLAGS_use_ngraph) EnableFusedOp(ctx.get());
341362
return ctx;
342363
}
343364

@@ -485,6 +506,5 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
485506
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
486507
#endif
487508
}
488-
489509
} // namespace framework
490510
} // namespace paddle
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#ifdef PADDLE_WITH_NGRAPH
16+
#include <algorithm>
17+
#include <functional>
18+
19+
#include "paddle/fluid/framework/ngraph_bridge.h"
20+
21+
#include "ngraph/ngraph.hpp"
22+
23+
namespace paddle {
24+
namespace framework {
25+
26+
// Definition of the op-type -> node-builder registry declared in
// ngraph_bridge.h. Starts empty; builders for supported op types are
// expected to be registered here (none are registered in this commit).
std::map<std::string,
         std::function<void(const std::shared_ptr<OperatorBase>&,
                            std::shared_ptr<std::unordered_map<
                                std::string, std::shared_ptr<ngraph::Node>>>)>>
    NgraphBridge::NG_NODE_MAP = {};
31+
32+
void NgraphBridge::build_graph(const std::shared_ptr<OperatorBase>& op) {
33+
auto& op_type = op->Type();
34+
NG_NODE_MAP[op_type](op, ngb_node_map);
35+
}
36+
37+
} // namespace framework
38+
} // namespace paddle
39+
#endif
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#ifdef PADDLE_WITH_NGRAPH
18+
19+
#include <algorithm>
20+
#include <map>
21+
#include <string>
22+
#include <unordered_map>
23+
#include <vector>
24+
25+
#include "paddle/fluid/framework/operator.h"
26+
#include "paddle/fluid/platform/enforce.h"
27+
28+
#include "ngraph/ngraph.hpp"
29+
30+
namespace paddle {
31+
namespace framework {
32+
33+
// Translates individual Paddle operators into ngraph graph nodes.
// A bridge holds a shared var-name -> ngraph::Node map that the registered
// builder functions read inputs from and write outputs into.
class NgraphBridge {
 public:
  // Registry: op type -> builder that appends that op's ngraph node(s) to
  // the shared node map. Empty-initialized in ngraph_bridge.cc; population
  // happens elsewhere (not visible in this commit).
  static std::map<
      std::string,
      std::function<void(const std::shared_ptr<OperatorBase>&,
                         std::shared_ptr<std::unordered_map<
                             std::string, std::shared_ptr<ngraph::Node>>>)>>
      NG_NODE_MAP;

  // The bridge shares, not owns, the node map: callers keep their own
  // shared_ptr and observe the nodes the builders insert.
  explicit NgraphBridge(
      std::shared_ptr<
          std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
          var_node_map)
      : ngb_node_map(var_node_map) {}

  // Dispatches `op` to the builder registered under op->Type() in
  // NG_NODE_MAP (see ngraph_bridge.cc).
  void build_graph(const std::shared_ptr<OperatorBase>& op);

 private:
  // var name -> ngraph node, shared with the creator of this bridge.
  std::shared_ptr<
      std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
      ngb_node_map;
};
55+
56+
} // namespace framework
57+
} // namespace paddle
58+
#endif
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#ifdef PADDLE_WITH_NGRAPH
16+
#include <glog/logging.h>
17+
18+
#include <algorithm>
19+
#include <map>
20+
21+
#include "paddle/fluid/framework/feed_fetch_type.h"
22+
#include "paddle/fluid/framework/ngraph_operator.h"
23+
#include "paddle/fluid/framework/shape_inference.h"
24+
#include "paddle/fluid/framework/var_desc.h"
25+
#include "paddle/fluid/framework/var_type.h"
26+
27+
namespace paddle {
28+
namespace framework {
29+
30+
// Translation table from Paddle variable data types to the corresponding
// ngraph element types. Types absent here (e.g. FP16, UINT8) cause
// FusedOperator::process() to throw.
static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
    {proto::VarType::FP32, ngraph::element::f32},
    {proto::VarType::FP64, ngraph::element::f64},
    {proto::VarType::INT32, ngraph::element::i32},
    {proto::VarType::INT64, ngraph::element::i64},
    {proto::VarType::BOOL, ngraph::element::boolean},
};
37+
38+
// Executes a sequence of fused Paddle ops as one ngraph computation.
// Constructed per-run by FusedOperator::RunImpl with metadata gathered by
// FusedOperator::process(). Run() is declared here but defined elsewhere
// (not visible in this commit).
class NgraphOperator {
 public:
  // NOTE(review): stores `scope` and `place` as reference members — the
  // referenced objects must outlive this NgraphOperator (true for the
  // transient use in FusedOperator::RunImpl).
  explicit NgraphOperator(const Scope& scope, const platform::Place& place,
                          const std::vector<std::shared_ptr<OperatorBase>>& ops,
                          const std::unordered_map<
                              std::string, ngraph::element::Type>& var_type_map,
                          const std::unordered_set<std::string>& persist,
                          const std::unordered_set<std::string>& fetches,
                          const std::unordered_set<std::string>& post_op_inputs,
                          int is_test_or_train)
      : scope(scope),
        place(place),
        fused_ops(ops),
        var_type_map(var_type_map),
        persistables(persist),
        fetches(fetches),
        post_op_inputs(post_op_inputs),
        is_test_or_train(is_test_or_train) {}

  // Runs the fused subgraph through ngraph (definition not in this file).
  void Run(const Scope& scope, const platform::Place& place) const;

 private:
  // Cache of compiled ngraph functions, shared across instances.
  // NOTE(review): static member — its out-of-line definition is not visible
  // in this commit; confirm it exists, otherwise linking will fail.
  static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
      func_cache;
  const Scope& scope;
  const platform::Place& place;
  // The ops this operator replaces, in original execution order.
  std::vector<std::shared_ptr<OperatorBase>> fused_ops;
  // var name -> ngraph element type (from FusedOperator::process()).
  std::unordered_map<std::string, ngraph::element::Type> var_type_map;
  // Names of persistable vars in the block.
  std::unordered_set<std::string> persistables;
  // Fetch targets of the block (inputs of fetch ops).
  std::unordered_set<std::string> fetches;
  // Inputs consumed by ops that run after the fused interval.
  std::unordered_set<std::string> post_op_inputs;
  // 0 = default; 1 = (is_test && not is_complete)
  // 2 = (is_test && is_complete)
  // 3 = (is_training && not is_complete)
  // 4 = (is_training && is_complete)
  int is_test_or_train;
};
75+
76+
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
77+
FusedOperator::FusedOpIntervals(
78+
std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops) {
79+
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
80+
intervals;
81+
if (ops->empty()) {
82+
return intervals;
83+
}
84+
size_t size = ops->size();
85+
size_t left = 0;
86+
while (left < size && ops.at(left)->Type() != kFeedOpType) {
87+
++left;
88+
}
89+
if (left == size) {
90+
return intervals;
91+
}
92+
while (left < size && ops->at(left)->Type() == kFeedOpType) {
93+
++left;
94+
}
95+
96+
size_t right = left;
97+
while (right < size && ops->at(right)->Type() != kFetchOpType) {
98+
++right;
99+
}
100+
if (right == size) {
101+
return intervals;
102+
}
103+
if (left >= right) return intervals;
104+
105+
// (left, right - 1) represents indices between feed and fetch
106+
size_t pivot = left;
107+
while (pivot < right) {
108+
auto op_type = ops->at(pivot)->Type();
109+
if (paddle::framework::NgraphBridge::NG_NODE_MAP.find(op_type) ==
110+
paddle::framework::NgraphBridge::NG_NODE_MAP.end()) {
111+
++pivot;
112+
} else {
113+
size_t start = pivot, end = start;
114+
while (pivot < right &&
115+
(paddle::framework::NgraphBridge::NG_NODE_MAP.find(
116+
ops.at(pivot)->Type()) !=
117+
paddle::framework::NgraphBridge::NG_NODE_MAP.end())) {
118+
++pivot;
119+
++end;
120+
}
121+
std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>
122+
interval = {ops->begin() + start, ops->begin() + end};
123+
intervals.push_back(interval);
124+
}
125+
} // end while
126+
127+
return intervals;
128+
}
129+
130+
// Builds a FusedOperator from the ops in [start, end), taking ownership of
// them (the source unique_ptr slots are left null for the caller to erase).
// Also records the inputs of every op between `end` and the first fetch op
// — those values must survive the fused execution — and marks the operator
// `is_complete` when the run is bounded immediately by feed and fetch ops.
// NOTE(review): default arguments on an out-of-line definition are visible
// only in this translation unit and are ill-formed if the declaration in
// ngraph_operator.h repeats them — confirm against the header.
FusedOperator::FusedOperator(
    const ProgramDesc& prog, size_t block_id,
    std::vector<std::unique_ptr<OperatorBase>>::iterator start,
    std::vector<std::unique_ptr<OperatorBase>>::iterator end,
    const std::string& type = "fused_op", const VariableNameMap& inputs = {},
    const VariableNameMap& outputs = {}, const AttributeMap& attrs = {})
    : OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) {
  // Move the interval's ops into this operator.
  for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
       it != end; ++it) {
    fused_ops.push_back(std::move(*it));
  }

  // Collect the inputs of all ops between the fused run and the fetch op.
  // NOTE(review): assumes a fetch op exists at or after `end` — guaranteed
  // when the interval came from FusedOpIntervals, which only returns runs
  // strictly before a fetch op; otherwise this walks off the vector.
  for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = end;
       (*it)->Type() != kFetchOpType; ++it) {
    for (auto& var_name_item : (*it)->Inputs()) {
      for (auto& var_name : var_name_item.second) {
        post_op_inputs.insert(var_name);
      }
    }
  }

  // Complete = the fused run directly abuts the feed ops on the left and the
  // fetch op on the right.
  // NOTE(review): dereferences start - 1; valid only when at least one op
  // precedes `start`, which FusedOpIntervals guarantees (a feed op).
  if ((*(start - 1))->Type() == kFeedOpType && (*end)->Type() == kFetchOpType) {
    is_complete = true;
  }

  // Gather var types, persistables and fetch targets from the block desc.
  process();
}
157+
158+
void FusedOperator::process() {
159+
auto& bdesc = pdesc.Block(block);
160+
for (auto& var : bdesc.AllVars()) {
161+
if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
162+
var->GetType() == proto::VarType::LOD_TENSOR ||
163+
var->GetType() == proto::VarType::LOD_TENSOR_ARRAY)) {
164+
continue;
165+
}
166+
167+
auto var_name = var->Name();
168+
if (var->Name() == framework::kEmptyVarName) {
169+
continue;
170+
}
171+
172+
if (var_name != "fetch" && var_name != "feed") {
173+
auto pd_type = var->GetDataType();
174+
if (pd2ng_type_map.find(pd_type) == pd2ng_type_map.end()) {
175+
PADDLE_THROW("Data type of var %s not found in pd2ng_type_map",
176+
var_name);
177+
}
178+
var_type_map[var_name] = pd2ng_type_map[pd_type];
179+
}
180+
181+
if (var->Persistable()) {
182+
persistables.insert(var->Name());
183+
}
184+
}
185+
186+
for (auto* op : bdesc.AllOps()) {
187+
if (op->Type() == kFetchOpType) {
188+
std::string fetch_target_name = op->Input("X")[0];
189+
fetches.insert(fetch_target_name);
190+
}
191+
}
192+
}
193+
194+
// Executes the fused subgraph: derives the run mode and delegates to a
// freshly constructed NgraphOperator.
void FusedOperator::RunImpl(const Scope& scope,
                            const platform::Place& place) const {
  // Mode encoding (see NgraphOperator::is_test_or_train):
  // 1 = test/partial, 2 = test/complete,
  // 3 = train/partial, 4 = train/complete.
  int is_test_or_train = 1;
  auto& bdesc = pdesc.Block(block);
  // Any *_grad op in the block marks this as a training run.
  for (auto* op : bdesc.AllOps()) {
    if (op->Type().find("_grad") != std::string::npos) {
      is_test_or_train = 3;
      break;
    }
  }

  // Upgrade partial -> complete when the fused run spans feed to fetch.
  if (is_complete) {
    is_test_or_train = is_test_or_train == 1 ? 2 : 4;
  }

  NgraphOperator ngraph_op(scope, place, fused_ops, var_type_map, persistables,
                           fetches, post_op_inputs, is_test_or_train);
  ngraph_op.Run(scope, place);
}
213+
214+
} // namespace framework
215+
} // namespace paddle
216+
#endif

0 commit comments

Comments
 (0)