Skip to content

Commit 2202f22

Browse files
committed
Merge branch 'windows/build' into windows/online
test=develop
2 parents e2a1cd1 + 0ef2a37 commit 2202f22

File tree

20 files changed

+605
-60
lines changed

20 files changed

+605
-60
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,8 @@ option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
7171
option(WITH_CONTRIB "Compile the third-party contributation" OFF)
7272
option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF)
7373
option(WITH_ANAKIN "Compile with Anakin library" OFF)
74+
option(ANAKIN_BUILD_FAT_BIN "Build anakin cuda fat-bin lib for all device plantform, ignored when WITH_ANAKIN=OFF" OFF)
75+
option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plantform. ignored when WITH_ANAKIN=OFF" ON)
7476
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
7577
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
7678
option(ON_INFER "Turn on inference optimization." OFF)

cmake/external/anakin.cmake

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,19 +58,21 @@ ExternalProject_Add(
5858
-DPROTOBUF_ROOT=${THIRD_PARTY_PATH}/install/protobuf
5959
-DMKLML_ROOT=${THIRD_PARTY_PATH}/install/mklml
6060
-DENABLE_OP_TIMER=${ANAKIN_ENABLE_OP_TIMER}
61+
-DBUILD_FAT_BIN=${ANAKIN_BUILD_FAT_BIN}
62+
-DBUILD_CROSS_PLANTFORM=${ANAKIN_BUILD_CROSS_PLANTFORM}
6163
${EXTERNAL_OPTIONAL_ARGS}
6264
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ANAKIN_INSTALL_DIR}
6365
)
6466

6567
message(STATUS "Anakin for inference is enabled")
6668
message(STATUS "Anakin is set INCLUDE:${ANAKIN_INCLUDE} LIBRARY:${ANAKIN_LIBRARY}")
67-
69+
add_dependencies(extern_anakin protobuf mklml)
6870
add_library(anakin_shared SHARED IMPORTED GLOBAL)
6971
set_property(TARGET anakin_shared PROPERTY IMPORTED_LOCATION ${ANAKIN_SHARED_LIB})
70-
add_dependencies(anakin_shared extern_anakin protobuf mklml)
72+
add_dependencies(anakin_shared extern_anakin)
7173

7274
add_library(anakin_saber SHARED IMPORTED GLOBAL)
7375
set_property(TARGET anakin_saber PROPERTY IMPORTED_LOCATION ${ANAKIN_SABER_LIB})
74-
add_dependencies(anakin_saber extern_anakin protobuf mklml)
76+
add_dependencies(anakin_saber extern_anakin)
7577

7678
list(APPEND external_project_dependencies anakin_shared anakin_saber)

paddle/fluid/framework/CMakeLists.txt

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,11 @@ cc_library(version SRCS version.cc)
136136
cc_test(version_test SRCS version_test.cc DEPS version)
137137

138138
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
139+
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto)
140+
if(NOT WIN32)
141+
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
142+
shape_inference data_transform lod_tensor profiler)
143+
endif(NOT WIN32)
139144

140145
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
141146
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
@@ -170,10 +175,14 @@ if(WITH_DISTRIBUTE)
170175
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
171176
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
172177
else()
173-
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
178+
if(NOT WIN32)
179+
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
180+
else(NOT WIN32)
181+
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
182+
endif(NOT WIN32)
174183
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
175184
endif()
176-
185+
177186
if (NOT WIN32)
178187
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
179188
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor

paddle/fluid/framework/executor.cc

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include "paddle/fluid/framework/feed_fetch_method.h"
1818
#include "paddle/fluid/framework/lod_rank_table.h"
1919
#include "paddle/fluid/framework/lod_tensor_array.h"
20+
#include "paddle/fluid/framework/ngraph_operator.h"
2021
#include "paddle/fluid/framework/op_registry.h"
2122
#include "paddle/fluid/framework/reader.h"
2223
#include "paddle/fluid/operators/detail/macros.h"
@@ -25,6 +26,7 @@ limitations under the License. */
2526

2627
DECLARE_bool(benchmark);
2728
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
29+
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
2830

2931
namespace paddle {
3032
namespace framework {
@@ -81,6 +83,24 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
8183
}
8284
}
8385

86+
static void EnableFusedOp(ExecutorPrepareContext* ctx) {
87+
#ifdef PADDLE_WITH_NGRAPH
88+
VLOG(3) << "use_ngraph=True";
89+
auto intervals = FusedOperator::FusedOpIntervals(&ctx->ops_);
90+
for (auto& interval : intervals) {
91+
auto* fused_op = new FusedOperator(ctx->prog_, ctx->block_id_,
92+
interval.at(0), interval.at(1));
93+
*interval[0] = std::unique_ptr<OperatorBase>(fused_op);
94+
}
95+
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
96+
ctx->ops_.erase(it->at(0) + 1, it->at(1));
97+
}
98+
#else
99+
LOG(WARNING)
100+
<< "'NGRAPH' is not supported, Please re-compile with WITH_NGRAPH option";
101+
#endif
102+
}
103+
84104
Executor::Executor(const platform::Place& place) : place_(place) {}
85105

86106
void Executor::Close() {
@@ -338,6 +358,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
338358
for (auto& op_desc : block.AllOps()) {
339359
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
340360
}
361+
if (FLAGS_use_ngraph) EnableFusedOp(ctx.get());
341362
return ctx;
342363
}
343364

@@ -486,6 +507,5 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
486507
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
487508
#endif
488509
}
489-
490510
} // namespace framework
491511
} // namespace paddle
Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,39 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#ifdef PADDLE_WITH_NGRAPH
16+
#include <algorithm>
17+
#include <functional>
18+
19+
#include "paddle/fluid/framework/ngraph_bridge.h"
20+
21+
#include "ngraph/ngraph.hpp"
22+
23+
namespace paddle {
24+
namespace framework {
25+
26+
std::map<std::string,
27+
std::function<void(const std::shared_ptr<OperatorBase>&,
28+
std::shared_ptr<std::unordered_map<
29+
std::string, std::shared_ptr<ngraph::Node>>>)>>
30+
NgraphBridge::NG_NODE_MAP = {};
31+
32+
void NgraphBridge::build_graph(const std::shared_ptr<OperatorBase>& op) {
33+
auto& op_type = op->Type();
34+
NG_NODE_MAP[op_type](op, ngb_node_map);
35+
}
36+
37+
} // namespace framework
38+
} // namespace paddle
39+
#endif
Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,58 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#ifdef PADDLE_WITH_NGRAPH
18+
19+
#include <algorithm>
20+
#include <map>
21+
#include <string>
22+
#include <unordered_map>
23+
#include <vector>
24+
25+
#include "paddle/fluid/framework/operator.h"
26+
#include "paddle/fluid/platform/enforce.h"
27+
28+
#include "ngraph/ngraph.hpp"
29+
30+
namespace paddle {
31+
namespace framework {
32+
33+
class NgraphBridge {
34+
public:
35+
static std::map<
36+
std::string,
37+
std::function<void(const std::shared_ptr<OperatorBase>&,
38+
std::shared_ptr<std::unordered_map<
39+
std::string, std::shared_ptr<ngraph::Node>>>)>>
40+
NG_NODE_MAP;
41+
42+
explicit NgraphBridge(
43+
std::shared_ptr<
44+
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
45+
var_node_map)
46+
: ngb_node_map(var_node_map) {}
47+
48+
void build_graph(const std::shared_ptr<OperatorBase>& op);
49+
50+
private:
51+
std::shared_ptr<
52+
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
53+
ngb_node_map;
54+
};
55+
56+
} // namespace framework
57+
} // namespace paddle
58+
#endif

0 commit comments

Comments
 (0)