Commit c7b3bfc

Merge pull request #14376 from baojun-nervana/intel/ngraph_fusedop

Adding fused operator for ngraph

2 parents: 83ddafb + 51a538e
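In short: this change wires an nGraph execution path into the fluid framework. NgraphBridge maps fluid op types to ngraph::Node builders, FusedOperator wraps each run of consecutive nGraph-supported ops between the feed and fetch sections, and the executor gains a use_ngraph flag that swaps those runs for fused ops when a program is prepared. All nGraph-specific code is guarded by PADDLE_WITH_NGRAPH.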

File tree: 7 files changed, +420 −7 lines

paddle/fluid/framework/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
```diff
@@ -136,6 +136,10 @@ cc_library(version SRCS version.cc)
 cc_test(version_test SRCS version_test.cc DEPS version)
 
 cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
+cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto)
+cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
+  shape_inference data_transform lod_tensor profiler)
+
 
 cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
 nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
@@ -163,10 +167,10 @@ if(WITH_DISTRIBUTE)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
   cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
 endif()
-
+
 
 if (NOT WIN32)
   cc_library(parallel_executor SRCS parallel_executor.cc DEPS
     threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
```
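The non-distributed executor target now links against ngraph_operator unconditionally; the new libraries always build, while their nGraph-specific bodies are compiled only when PADDLE_WITH_NGRAPH is defined.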

paddle/fluid/framework/executor.cc

Lines changed: 21 additions & 1 deletion

```diff
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
+#include "paddle/fluid/framework/ngraph_operator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/reader.h"
 #include "paddle/fluid/operators/detail/macros.h"
```
```diff
@@ -25,6 +26,7 @@ limitations under the License. */
 
 DECLARE_bool(benchmark);
 DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
+DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
 
 namespace paddle {
 namespace framework {
```
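Because use_ngraph is an ordinary gflags flag, it can be flipped on the command line or programmatically before the executor prepares a program. A minimal sketch, assuming a standard gflags setup (the trainer binary and surrounding code are hypothetical):

```cpp
#include <gflags/gflags.h>

DECLARE_bool(use_ngraph);  // defined in executor.cc above

int main(int argc, char* argv[]) {
  // Picks up e.g. ./my_trainer --use_ngraph=true from the command line.
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);

  FLAGS_use_ngraph = true;  // or force the fused-op path in code

  // ... build the ProgramDesc/Scope and call Executor::Prepare()/Run()
  return 0;
}
```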
```diff
@@ -81,6 +83,24 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
   }
 }
 
+static void EnableFusedOp(ExecutorPrepareContext* ctx) {
+#ifdef PADDLE_WITH_NGRAPH
+  VLOG(3) << "use_ngraph=True";
+  auto intervals = FusedOperator::FusedOpIntervals(&ctx->ops_);
+  for (auto& interval : intervals) {
+    auto* fused_op = new FusedOperator(ctx->prog_, ctx->block_id_,
+                                       interval.at(0), interval.at(1));
+    *interval[0] = std::unique_ptr<OperatorBase>(fused_op);
+  }
+  for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
+    ctx->ops_.erase(it->at(0) + 1, it->at(1));
+  }
+#else
+  LOG(WARNING)
+      << "'NGRAPH' is not supported, Please re-compile with WITH_NGRAPH option";
+#endif
+}
+
 Executor::Executor(const platform::Place& place) : place_(place) {}
 
 void Executor::Close() {
```
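EnableFusedOp rewrites ctx->ops_ in place: for each interval of fusable ops, the first slot is overwritten with a FusedOperator that takes ownership of the whole run, and the remaining slots (from it->at(0) + 1 up to it->at(1)) are erased afterwards. Erasing the intervals in reverse order matters: a vector erase only invalidates iterators at or after the erase point, so working back to front keeps the stored iterators of the earlier intervals valid.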
```diff
@@ -338,6 +358,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
   for (auto& op_desc : block.AllOps()) {
     ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
+  if (FLAGS_use_ngraph) EnableFusedOp(ctx.get());
   return ctx;
 }
 
```
```diff
@@ -486,6 +507,5 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
       << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
 #endif
 }
-
 }  // namespace framework
 }  // namespace paddle
```
paddle/fluid/framework/ngraph_bridge.cc (new file)

Lines changed: 39 additions & 0 deletions

```cpp
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_NGRAPH
#include <algorithm>
#include <functional>

#include "paddle/fluid/framework/ngraph_bridge.h"

#include "ngraph/ngraph.hpp"

namespace paddle {
namespace framework {

std::map<std::string,
         std::function<void(const std::shared_ptr<OperatorBase>&,
                            std::shared_ptr<std::unordered_map<
                                std::string, std::shared_ptr<ngraph::Node>>>)>>
    NgraphBridge::NG_NODE_MAP = {};

void NgraphBridge::build_graph(const std::shared_ptr<OperatorBase>& op) {
  auto& op_type = op->Type();
  // Callers are expected to pass only op types present in NG_NODE_MAP
  // (FusedOpIntervals groups exactly those ops); an unknown type would
  // invoke an empty std::function here.
  NG_NODE_MAP[op_type](op, ngb_node_map);
}

}  // namespace framework
}  // namespace paddle
#endif
```
paddle/fluid/framework/ngraph_bridge.h (new file)

Lines changed: 58 additions & 0 deletions

```cpp
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_NGRAPH

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"

#include "ngraph/ngraph.hpp"

namespace paddle {
namespace framework {

class NgraphBridge {
 public:
  // Maps a fluid op type to a builder that appends the matching
  // ngraph::Node(s) to the shared var-name -> node map.
  static std::map<
      std::string,
      std::function<void(const std::shared_ptr<OperatorBase>&,
                         std::shared_ptr<std::unordered_map<
                             std::string, std::shared_ptr<ngraph::Node>>>)>>
      NG_NODE_MAP;

  explicit NgraphBridge(
      std::shared_ptr<
          std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
          var_node_map)
      : ngb_node_map(var_node_map) {}

  void build_graph(const std::shared_ptr<OperatorBase>& op);

 private:
  // Shared with the surrounding operator: variable name -> ngraph node.
  std::shared_ptr<
      std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
      ngb_node_map;
};

}  // namespace framework
}  // namespace paddle
#endif
```
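NG_NODE_MAP starts out empty in this commit, so FusedOpIntervals will not yet find anything to fuse; converters are evidently meant to be registered in follow-up work. A hedged sketch of what a registration could look like: the "relu" key, the ngraph::op::Relu node, and RegisterReluExample are illustrative assumptions, not part of this change.

```cpp
#include <memory>
#include <string>
#include <unordered_map>

#include "ngraph/ngraph.hpp"
#include "paddle/fluid/framework/ngraph_bridge.h"

namespace paddle {
namespace framework {

// Hypothetical converter registration: translate a fluid "relu" op into an
// ngraph::op::Relu node, reading and writing the shared var-name -> node map.
void RegisterReluExample() {
  NgraphBridge::NG_NODE_MAP["relu"] =
      [](const std::shared_ptr<OperatorBase>& op,
         std::shared_ptr<std::unordered_map<std::string,
                                            std::shared_ptr<ngraph::Node>>>
             map) {
        auto x = map->at(op->Input("X"));  // node for the op's input variable
        (*map)[op->Output("Out")] = std::make_shared<ngraph::op::Relu>(x);
      };
}

}  // namespace framework
}  // namespace paddle
```

With a converter registered, an NgraphBridge constructed over a shared node map can translate one op at a time via bridge.build_graph(op).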
paddle/fluid/framework/ngraph_operator.cc (new file)

Lines changed: 220 additions & 0 deletions
```cpp
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_NGRAPH
#include <glog/logging.h>

#include <algorithm>
#include <map>

#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ngraph_operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h"

namespace paddle {
namespace framework {

static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
    {proto::VarType::FP32, ngraph::element::f32},
    {proto::VarType::FP64, ngraph::element::f64},
    {proto::VarType::INT32, ngraph::element::i32},
    {proto::VarType::INT64, ngraph::element::i64},
    {proto::VarType::BOOL, ngraph::element::boolean},
};

typedef enum {   /* nGraph support state on ops */
  FULL_TRAIN,    /* Support full ops for train */
  PARTIAL_TRAIN, /* Support partial ops for train */
  FULL_TEST,     /* Support full list of ops for test */
  PARTIAL_TEST   /* Support partial list of ops for test */
} op_state;

class NgraphOperator {
 public:
  explicit NgraphOperator(const Scope& scope, const platform::Place& place,
                          const std::vector<std::shared_ptr<OperatorBase>>& ops,
                          const std::unordered_map<
                              std::string, ngraph::element::Type>& var_type_map,
                          const std::unordered_set<std::string>& persist,
                          const std::unordered_set<std::string>& fetches,
                          const std::unordered_set<std::string>& post_op_inputs,
                          op_state ng_op_state)
      : scope_(scope),
        place_(place),
        fused_ops_(ops),
        var_type_map_(var_type_map),
        persistables_(persist),
        fetches_(fetches),
        post_op_inputs_(post_op_inputs),
        ng_op_state_(ng_op_state) {}

  void Run(const Scope& scope, const platform::Place& place) const;

 private:
  static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
      func_cache;
  const Scope& scope_;
  const platform::Place& place_;
  std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
  std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
  std::unordered_set<std::string> persistables_;
  std::unordered_set<std::string> fetches_;
  std::unordered_set<std::string> post_op_inputs_;
  op_state ng_op_state_;
};
```
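Note the split of responsibilities: NgraphOperator is a file-local helper that will own the actual nGraph function construction and execution (its Run and func_cache are declared but not defined in this file), while FusedOperator below is the OperatorBase subclass that the executor actually schedules.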
```cpp
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
FusedOperator::FusedOpIntervals(
    std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops) {
  std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
      intervals;
  if (ops->empty()) {
    return intervals;
  }
  size_t size = ops->size();
  size_t left = 0;
  while (left < size && ops->at(left)->Type() != kFeedOpType) {
    ++left;
  }
  if (left == size) {
    return intervals;
  }
  while (left < size && ops->at(left)->Type() == kFeedOpType) {
    ++left;
  }

  size_t right = left;
  while (right < size && ops->at(right)->Type() != kFetchOpType) {
    ++right;
  }
  if (right == size) {
    return intervals;
  }
  if (left >= right) return intervals;

  // (left, right - 1) represents indices between feed and fetch
  size_t pivot = left;
  while (pivot < right) {
    auto op_type = ops->at(pivot)->Type();
    if (paddle::framework::NgraphBridge::NG_NODE_MAP.find(op_type) ==
        paddle::framework::NgraphBridge::NG_NODE_MAP.end()) {
      ++pivot;
    } else {
      size_t start = pivot, end = start;
      while (pivot < right &&
             (paddle::framework::NgraphBridge::NG_NODE_MAP.find(
                  ops->at(pivot)->Type()) !=
              paddle::framework::NgraphBridge::NG_NODE_MAP.end())) {
        ++pivot;
        ++end;
      }
      std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>
          interval = {ops->begin() + start, ops->begin() + end};
      intervals.push_back(interval);
    }
  }  // end while

  return intervals;
}
```
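The scan above works in three steps: advance past the feed section, find the first fetch op, and within that window collect every maximal run of ops whose types appear in NgraphBridge::NG_NODE_MAP. A standalone sketch of the same scan over plain op-type strings, where the supported set and the op sequence are made up for illustration:

```cpp
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Stand-in for membership in NgraphBridge::NG_NODE_MAP.
  std::set<std::string> supported = {"mul", "elementwise_add", "relu"};
  std::vector<std::string> ops = {"feed", "mul", "elementwise_add",
                                  "dropout", "relu", "fetch"};

  std::vector<std::pair<size_t, size_t>> intervals;  // [start, end)
  size_t pivot = 1, right = ops.size() - 1;  // window between feed and fetch
  while (pivot < right) {
    if (supported.count(ops[pivot]) == 0) {
      ++pivot;  // unsupported op: leave it alone
    } else {
      size_t start = pivot;
      while (pivot < right && supported.count(ops[pivot]) != 0) ++pivot;
      intervals.emplace_back(start, pivot);  // maximal fusable run
    }
  }

  for (auto& iv : intervals)  // prints [1, 3) and [4, 5)
    std::cout << "[" << iv.first << ", " << iv.second << ")\n";
}
```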
```cpp
FusedOperator::FusedOperator(
    const ProgramDesc& prog, size_t block_id,
    std::vector<std::unique_ptr<OperatorBase>>::iterator start,
    std::vector<std::unique_ptr<OperatorBase>>::iterator end,
    const std::string& type, const VariableNameMap& inputs,
    const VariableNameMap& outputs, const AttributeMap& attrs)
    : OperatorBase(type, inputs, outputs, attrs),
      pdesc_(prog),
      block_(block_id) {
  for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
       it != end; ++it) {
    fused_ops_.push_back(std::move(*it));
  }

  // Collect every input consumed after the fused interval, up to the first
  // fetch op; those variables must stay live after the fused op runs.
  for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = end;
       (*it)->Type() != kFetchOpType; ++it) {
    for (auto& var_name_item : (*it)->Inputs()) {
      for (auto& var_name : var_name_item.second) {
        post_op_inputs_.insert(var_name);
      }
    }
  }

  // FusedOpIntervals only yields intervals strictly between the feed and
  // fetch sections, so peeking at *(start - 1) and *end is safe here.
  if ((*(start - 1))->Type() == kFeedOpType && (*end)->Type() == kFetchOpType) {
    is_full_ = true;
  }

  Process();
}

void FusedOperator::Process() {
  auto& bdesc = pdesc_.Block(block_);
  for (auto& var : bdesc.AllVars()) {
    if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
          var->GetType() == proto::VarType::LOD_TENSOR ||
          var->GetType() == proto::VarType::LOD_TENSOR_ARRAY)) {
      continue;
    }

    auto var_name = var->Name();
    if (var->Name() == framework::kEmptyVarName) {
      continue;
    }

    if (var_name != "fetch" && var_name != "feed") {
      auto pd_type = var->GetDataType();
      if (pd2ng_type_map.find(pd_type) == pd2ng_type_map.end()) {
        PADDLE_THROW("Data type of var %s not found in pd2ng_type_map",
                     var_name);
      }
      var_type_map_[var_name] = pd2ng_type_map[pd_type];
    }

    if (var->Persistable()) {
      persistables_.insert(var->Name());
    }
  }

  for (auto* op : bdesc.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
      fetches_.insert(fetch_target_name);
    }
  }
}
```
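Process() is a one-time pass over the block description: it records an nGraph element type for every tensor-like variable (throwing on data types missing from pd2ng_type_map), collects the persistable variable names, and remembers the fetch targets, so that RunImpl below can hand NgraphOperator everything it needs.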
```cpp
void FusedOperator::RunImpl(const Scope& scope,
                            const platform::Place& place) const {
  op_state ng_op_state = PARTIAL_TEST;
  auto& bdesc = pdesc_.Block(block_);
  for (auto* op : bdesc.AllOps()) {
    if (op->Type().find("_grad") != std::string::npos) {
      ng_op_state = PARTIAL_TRAIN;
      break;
    }
  }

  if (is_full_) {
    ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
  }

  NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_,
                           persistables_, fetches_, post_op_inputs_,
                           ng_op_state);
  ngraph_op.Run(scope, place);
}

}  // namespace framework
}  // namespace paddle
#endif
```
