Skip to content

Commit 8e2dea5

Browse files
authored
Merge pull request #15538 from baojun-nervana/mv_ng_bridge_file
move ngraph_bridge to ngraph directory
2 parents e2818c8 + 8e9308a commit 8e2dea5

File tree

5 files changed

+31
-35
lines changed

5 files changed

+31
-35
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,6 @@ cc_test(version_test SRCS version_test.cc DEPS version)
129129

130130
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
131131

132-
if(WITH_NGRAPH)
133-
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
134-
endif(WITH_NGRAPH)
135-
136132
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
137133
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
138134

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
if(WITH_NGRAPH)
2+
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
23
cc_library(ngraph_engine SRCS ngraph_engine.cc DEPS ngraph_bridge framework_proto)
34
op_library(ngraph_engine_op DEPS ngraph_engine op_registry op_info device_context)
45
endif()

paddle/fluid/framework/ngraph_bridge.cc renamed to paddle/fluid/operators/ngraph/ngraph_bridge.cc

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,39 +17,39 @@ limitations under the License. */
1717
#include <vector>
1818

1919
#include "ngraph/ngraph.hpp"
20-
#include "paddle/fluid/framework/ngraph_bridge.h"
21-
#include "paddle/fluid/framework/operator.h"
20+
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
2221
#include "paddle/fluid/operators/ngraph/ngraph_ops.h"
2322
#include "paddle/fluid/platform/enforce.h"
2423
#include "paddle/fluid/platform/ngraph_helper.h"
2524

2625
namespace paddle {
27-
namespace framework {
26+
namespace operators {
2827

2928
namespace NG_OPS = paddle::operators::ngraphs;
3029
std::map<std::string,
31-
std::function<void(const std::shared_ptr<OperatorBase>&,
30+
std::function<void(const std::shared_ptr<framework::OperatorBase>&,
3231
std::shared_ptr<std::unordered_map<
3332
std::string, std::shared_ptr<ngraph::Node>>>)>>
3433
NgraphBridge::NG_NODE_MAP = {
3534
{"elementwise_add", NG_OPS::BuildElementwiseAddNode},
3635
{"elementwise_add_grad", NG_OPS::BuildElementwiseAddGradNode},
37-
{"fill_constant", paddle::operators::ngraphs::BuildFillConstantNode},
38-
{"mean", paddle::operators::ngraphs::BuildMeanNode},
39-
{"mean_grad", paddle::operators::ngraphs::BuildMeanGradNode},
40-
{"mul", paddle::operators::ngraphs::BuildMulNode},
41-
{"mul_grad", paddle::operators::ngraphs::BuildMulGradNode},
42-
{"softmax", paddle::operators::ngraphs::BuildSoftmaxNode},
43-
{"softmax_grad", paddle::operators::ngraphs::BuildSoftmaxGradNode},
44-
{"scale", paddle::operators::ngraphs::BuildScaleNode},
45-
{"relu", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Relu>},
46-
{"tanh", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Tanh>},
47-
{"top_k", paddle::operators::ngraphs::BuildTopKNode}};
48-
49-
void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) {
36+
{"fill_constant", NG_OPS::BuildFillConstantNode},
37+
{"mean", NG_OPS::BuildMeanNode},
38+
{"mean_grad", NG_OPS::BuildMeanGradNode},
39+
{"mul", NG_OPS::BuildMulNode},
40+
{"mul_grad", NG_OPS::BuildMulGradNode},
41+
{"softmax", NG_OPS::BuildSoftmaxNode},
42+
{"softmax_grad", NG_OPS::BuildSoftmaxGradNode},
43+
{"scale", NG_OPS::BuildScaleNode},
44+
{"relu", NG_OPS::BuildUnaryNode<ngraph::op::Relu>},
45+
{"tanh", NG_OPS::BuildUnaryNode<ngraph::op::Tanh>},
46+
{"top_k", NG_OPS::BuildTopKNode}};
47+
48+
void NgraphBridge::BuildNgNode(
49+
const std::shared_ptr<framework::OperatorBase>& op) {
5050
auto& op_type = op->Type();
5151
NG_NODE_MAP[op_type](op, ngb_node_map_);
5252
}
5353

54-
} // namespace framework
54+
} // namespace operators
5555
} // namespace paddle

paddle/fluid/framework/ngraph_bridge.h renamed to paddle/fluid/operators/ngraph/ngraph_bridge.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ limitations under the License. */
2121

2222
#include "ngraph/node.hpp"
2323

24-
namespace paddle {
25-
namespace framework {
24+
#include "paddle/fluid/framework/operator.h"
2625

27-
class OperatorBase;
26+
namespace paddle {
27+
namespace operators {
2828

2929
class NgraphBridge {
3030
public:
3131
static std::map<
3232
std::string,
33-
std::function<void(const std::shared_ptr<OperatorBase>&,
33+
std::function<void(const std::shared_ptr<framework::OperatorBase>&,
3434
std::shared_ptr<std::unordered_map<
3535
std::string, std::shared_ptr<ngraph::Node>>>)>>
3636
NG_NODE_MAP;
@@ -41,13 +41,13 @@ class NgraphBridge {
4141
var_node_map)
4242
: ngb_node_map_(var_node_map) {}
4343

44-
void BuildNgNode(const std::shared_ptr<OperatorBase>& op);
44+
void BuildNgNode(const std::shared_ptr<framework::OperatorBase>& op);
4545

4646
private:
4747
std::shared_ptr<
4848
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
4949
ngb_node_map_;
5050
};
5151

52-
} // namespace framework
52+
} // namespace operators
5353
} // namespace paddle

paddle/fluid/operators/ngraph/ngraph_engine.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ limitations under the License. */
2424
#include "paddle/fluid/framework/feed_fetch_type.h"
2525
#include "paddle/fluid/framework/framework.pb.h"
2626
#include "paddle/fluid/framework/lod_tensor.h"
27-
#include "paddle/fluid/framework/ngraph_bridge.h"
2827
#include "paddle/fluid/framework/op_desc.h"
2928
#include "paddle/fluid/framework/op_registry.h"
3029
#include "paddle/fluid/framework/var_desc.h"
3130
#include "paddle/fluid/framework/var_type.h"
31+
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
3232
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
3333

3434
namespace paddle {
@@ -88,15 +88,14 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
8888
int pivot = left;
8989
while (pivot < right) {
9090
auto op_type = ops.at(pivot)->Type();
91-
if (paddle::framework::NgraphBridge::NG_NODE_MAP.find(op_type) ==
92-
paddle::framework::NgraphBridge::NG_NODE_MAP.end()) {
91+
if (NgraphBridge::NG_NODE_MAP.find(op_type) ==
92+
NgraphBridge::NG_NODE_MAP.end()) {
9393
++pivot;
9494
} else {
9595
int start = pivot, end = start;
9696
while (pivot < right &&
97-
(paddle::framework::NgraphBridge::NG_NODE_MAP.find(
98-
ops.at(pivot)->Type()) !=
99-
paddle::framework::NgraphBridge::NG_NODE_MAP.end())) {
97+
(NgraphBridge::NG_NODE_MAP.find(ops.at(pivot)->Type()) !=
98+
NgraphBridge::NG_NODE_MAP.end())) {
10099
++pivot;
101100
++end;
102101
}
@@ -283,7 +282,7 @@ void NgraphEngine::BuildNgNodes() {
283282
}
284283
}
285284
}
286-
framework::NgraphBridge ngb(var_node_map_);
285+
NgraphBridge ngb(var_node_map_);
287286
for (auto& op : fused_ops_) {
288287
ngb.BuildNgNode(op);
289288
}

0 commit comments

Comments
 (0)