
Commit f4bcee1

Merge branch 'develop' into anakin_test
2 parents 94042cc + 772ceee

38 files changed: 2122 additions, 101 deletions

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -175,6 +175,7 @@ include(external/any) # download libn::any
 include(external/eigen) # download eigen3
 include(external/pybind11) # download pybind11
 include(external/cares)
+include(external/cub)

 if(WITH_DISTRIBUTE)
   if(WITH_GRPC)

cmake/external/cub.cmake (new file)

Lines changed: 35 additions & 0 deletions

if(NOT WITH_GPU)
  return()
endif()

include(ExternalProject)

set(CUB_SOURCE_DIR ${THIRD_PARTY_PATH}/cub)
set(CUB_INCLUDE_DIR ${CUB_SOURCE_DIR}/src/extern_cub)

include_directories(${CUB_INCLUDE_DIR})

ExternalProject_Add(
  extern_cub
  ${EXTERNAL_PROJECT_LOG_ARGS}
  GIT_REPOSITORY "https://github.com/NVlabs/cub.git"
  GIT_TAG        "v1.8.0"
  PREFIX         ${CUB_SOURCE_DIR}
  UPDATE_COMMAND    ""
  CONFIGURE_COMMAND ""
  BUILD_COMMAND     ""
  INSTALL_COMMAND   ""
  TEST_COMMAND      ""
)

if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
  set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cub_dummy.c)
  file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";")
  add_library(cub STATIC ${dummyfile})
else()
  add_library(cub INTERFACE)
endif()

add_dependencies(cub extern_cub)

LIST(APPEND externl_project_dependencies cub)

doc/survey/op_fusion_design.md (new file)

Lines changed: 20 additions & 0 deletions

# Operator fusion

Fusing multiple operators together is an important way to optimize program execution, particularly on GPUs and other specialized accelerators. An obvious benefit is avoiding the overhead of writing intermediate results back to global memory.

There are generally two ways to fuse operators: fusing directly connected operators, and fusing operators that are not directly connected. The first method is mainly used by the [NNVM Compiler](https://github.com/dmlc/tvm/) and [XLA](https://www.tensorflow.org/performance/xla/). The second method is mainly used by DyNet and TensorFlow Fold for auto-batching. The principle of operator fusion is to combine multiple operations into one according to a set of rules; for example, `Y = X * W` and `Z = Y + B` can be fused into `Z = X * W + B`, and `Y1 = X1 * W` and `Y2 = X2 * W` can be fused into `[Y1;Y2] = [X1;X2] * W`. To get a short-term benefit, we decided to start by specifying these rules manually.
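
As a rough illustration of the second rule (two multiplications sharing the same weight `W` batched into one), the C++ sketch below stacks the inputs row-wise and issues a single matrix multiplication instead of two. The naive `matmul` helper and the explicit copy are purely illustrative; a real implementation would lay `X1` and `X2` out contiguously up front and call an optimized GEMM.

```cpp
#include <cstddef>
#include <vector>

// Naive row-major GEMM, C(m x n) = A(m x k) * B(k x n). Illustration only.
void matmul(const std::vector<float>& A, const std::vector<float>& B,
            std::vector<float>* C, size_t m, size_t k, size_t n) {
  C->assign(m * n, 0.f);
  for (size_t i = 0; i < m; ++i)
    for (size_t p = 0; p < k; ++p)
      for (size_t j = 0; j < n; ++j)
        (*C)[i * n + j] += A[i * k + p] * B[p * n + j];
}

// Unfused: Y1 = X1 * W and Y2 = X2 * W are two separate GEMM calls (and, on
// a GPU, two kernel launches).
// Fused:   [Y1;Y2] = [X1;X2] * W is one call over the stacked input; the
// first m1 rows of Y hold Y1 and the remaining m2 rows hold Y2.
void FusedMatmul(const std::vector<float>& X1, const std::vector<float>& X2,
                 const std::vector<float>& W, std::vector<float>* Y,
                 size_t m1, size_t m2, size_t k, size_t n) {
  std::vector<float> X(X1);                 // rows of X1
  X.insert(X.end(), X2.begin(), X2.end());  // followed by rows of X2
  matmul(X, W, Y, m1 + m2, k, n);           // one GEMM instead of two
}
```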

## Challenge

The challenges of fusing operators are:
- how to make the rules;
- how to implement these rules efficiently.

### How to make the rules?

Determining the best placement of a fused operator is an NP-hard combinatorial problem. After analyzing the operators in typical DL models, we found two groups of operators that can be fused explicitly: one group is simple, adjacent operations, for example `tmp = x + y` followed by `z = Relu(tmp)`; the other is operators that share the same function, for example a series of `SGD` or `Momentum` operators. Both groups usually appear in a model in large numbers, so we should first consider how to fuse each group separately.

### How to implement these rules efficiently?

#### How to fuse adjacent operations efficiently?

Here we use a template function to represent the fused operations. The pros of a template function are that it is simple and efficient; the cons are that it is not easy to extend and can only express simple operations. Given our current needs, a template function is the more appropriate choice.
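
A minimal CPU sketch of the template-function idea for the `tmp = x + y`, `z = Relu(tmp)` case (the functor and function names below are illustrative, not Paddle's actual ones): the two operations are composed inside one loop, so the intermediate value stays in a register instead of being written to and re-read from memory.

```cpp
#include <algorithm>
#include <cstddef>

struct Add {
  float operator()(float x, float y) const { return x + y; }
};

struct Relu {
  float operator()(float v) const { return std::max(v, 0.f); }
};

// Fused elementwise kernel: computes z[i] = UnaryOp(BinaryOp(x[i], y[i])) in
// a single pass, so the intermediate `tmp = x + y` never reaches memory.
template <typename BinaryOp, typename UnaryOp>
void FusedElementwise(const float* x, const float* y, float* z, size_t n,
                      BinaryOp bin_op, UnaryOp un_op) {
  for (size_t i = 0; i < n; ++i) {
    z[i] = un_op(bin_op(x[i], y[i]));
  }
}

// Usage: z = Relu(x + y) in one traversal.
//   FusedElementwise(x, y, z, n, Add(), Relu());
```

The composition is resolved at compile time, which is what makes this approach simple and efficient, and also why it is hard to extend beyond simple elementwise patterns.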

#### How to fuse operators that have the same function efficiently?

Take the SGD operator as an example: the training model may have hundreds of parameters and, correspondingly, the same number of SGD operators. The update expression (`w = w - lr * w_g`) is identical for all of them, so during training the executor executes this expression hundreds of times on the CPU or another specialized accelerator. If we can fuse them, and make the addresses of all the `w` tensors and all the `w_g` tensors contiguous respectively, we only need to execute the expression once. For some accelerators the kernel launch time is not negligible, so launching and executing a kernel hundreds of times may cost much more than launching and executing it only once. A DL model usually contains many operators similar to `SGD` in this respect, such as `AllReduce` and `FC`.
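
A schematic CPU version of this idea, assuming the framework has already packed every `w` into one contiguous buffer and every `w_g` into another (the packing step is omitted here): hundreds of per-parameter updates collapse into a single call, which on a GPU would mean a single kernel launch instead of hundreds.

```cpp
#include <cstddef>

// Unfused: one call (and, on a GPU, one kernel launch) per parameter tensor,
// repeated for every parameter in the model.
void SGDUpdate(float* w, const float* w_g, size_t n, float lr) {
  for (size_t i = 0; i < n; ++i) w[i] -= lr * w_g[i];
}

// Fused: with all parameters and all gradients packed into two contiguous
// buffers, every update is done by one call over the total element count.
void FusedSGDUpdate(float* packed_w, const float* packed_w_g,
                    size_t total_numel, float lr) {
  for (size_t i = 0; i < total_numel; ++i) {
    packed_w[i] -= lr * packed_w_g[i];
  }
}
```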

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -336,6 +336,7 @@ paddle.fluid.contrib.BeamSearchDecoder.decode ArgSpec(args=['self'], varargs=Non
 paddle.fluid.contrib.BeamSearchDecoder.early_stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.contrib.BeamSearchDecoder.read_array ArgSpec(args=['self', 'init', 'is_ids', 'is_scores'], varargs=None, keywords=None, defaults=(False, False))
 paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array', 'value'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.DistributeTranspiler.create_splited_vars ArgSpec(args=['self', 'source_var', 'block', 'tag'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 8 additions & 6 deletions
@@ -275,7 +275,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
          if (strategy_.gradient_scale_ !=
              BuildStrategy::GradientScaleStrategy::kCustomized) {
            // TODO(paddle-dev): Why is there no input for this op_handle?
-           CreateScaleLossGradOp(&result);
+           auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
+           CreateScaleLossGradOp(&result, loss_grad_name);
          }
          // This assumes the backward generating code will ensure IsScaleLossOp
          // is true only for the op that scale the final scalar loss.
@@ -535,7 +536,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
   return got == sharded_var_device.end() ? -1 : got->second;
 }

-void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
+void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
+    ir::Graph *result, const std::string &loss_grad_name) const {
   for (size_t i = 0; i < places_.size(); ++i) {
     // Insert ScaleCost OpHandle
 #ifdef PADDLE_WITH_CUDA
@@ -558,10 +560,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
     // loss->pending_ops_.emplace_back(op_handle);
     // op_handle->inputs_.emplace_back(loss);

-    CreateOpOutput(result, op_handle,
-                   result->CreateEmptyNode(GradVarName(loss_var_name_),
-                                           ir::Node::Type::kVariable),
-                   places_[i], i);
+    CreateOpOutput(
+        result, op_handle,
+        result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable),
+        places_[i], i);
   }
 }

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 3 additions & 1 deletion
@@ -75,7 +75,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   void CreateComputationalOps(ir::Graph *result, ir::Node *node,
                               size_t num_places) const;

-  void CreateScaleLossGradOp(ir::Graph *result) const;
+  void CreateScaleLossGradOp(ir::Graph *result,
+                             const std::string &loss_grad_name) const;
+
   VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
                             int dst_dev_id) const;
   void CreateComputationalOp(ir::Graph *result, ir::Node *node,

paddle/fluid/framework/executor.cc

Lines changed: 0 additions & 5 deletions
@@ -330,12 +330,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }

   for (auto& op : ctx->ops_) {
-    VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
     op->Run(*local_scope, place_);
-    // NOTE! Please do not delete this line, it's usefull because the debug
-    // string before and after op.run are different, after run the output
-    // will have right shape which is usefull for debug.
-    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);

     if (FLAGS_benchmark) {
       VLOG(2) << "Memory used after operator " + op->Type() + " running: "

paddle/fluid/framework/operator.cc

Lines changed: 6 additions & 4 deletions
@@ -127,7 +127,7 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
 }

 void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
-  VLOG(10) << "- " << DebugStringEx(&scope);
+  VLOG(4) << place << " " << DebugStringEx(&scope);
   if (platform::is_gpu_place(place)) {
 #ifndef PADDLE_WITH_CUDA
     PADDLE_THROW("Cannot run operator on place %s", place);
@@ -139,7 +139,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   platform::RecordEvent record_event(Type(), pool.Get(place));
   RunImpl(scope, place);
-  VLOG(10) << "+ " << DebugStringEx(&scope);
+  VLOG(3) << place << " " << DebugStringEx(&scope);
 }

 bool OperatorBase::HasInputs(const std::string& name) const {
@@ -778,6 +778,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
     const ExecutionContext& ctx) const {
   auto& scope = ctx.scope();
   int data_type = -1;
+  std::string last_input_name;
   for (auto& input : this->inputs_) {
     for (auto& ipt_name : input.second) {
       auto* var = scope.FindVar(ipt_name);
@@ -794,9 +795,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
         int tmp = static_cast<int>(ToDataType(t->type()));
         PADDLE_ENFORCE(
             tmp == data_type || data_type == -1,
-            "DataType of Paddle Op %s must be the same. Get %d != %d", Type(),
-            data_type, tmp);
+            "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
+            Type(), last_input_name, data_type, ipt_name, tmp);
         data_type = tmp;
+        last_input_name = ipt_name;
       }
     }
   }

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 12 additions & 3 deletions
@@ -24,7 +24,7 @@

 namespace paddle {

-DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
+DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, true,
             "Enable subgraph to TensorRT engine for acceleration");

 DEFINE_string(inference_analysis_graphviz_log_root, "./",
@@ -42,10 +42,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
     // TODO(Superjomn) set the key with pass reprs.
     AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
     if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) {
-      auto trt_teller = [](const Node* node) {
+      auto trt_teller = [&](const Node* node) {
+        std::unordered_set<std::string> teller_set(
+            {"elementwise_add", "mul", "conv2d", "pool2d", "relu"});
         if (!node->IsFunction()) return false;
-        return static_cast<const Function*>(node)->func_type() == "mul";
+
+        const auto* func = static_cast<const Function*>(node);
+        if (teller_set.count(func->func_type()))
+          return true;
+        else {
+          return false;
+        }
       };
+
       AddPass("tensorrt-subgraph-marker",
               new TensorRTSubgraphNodeMarkPass(trt_teller));
       AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));

paddle/fluid/inference/analysis/data_flow_graph.cc

Lines changed: 28 additions & 0 deletions
@@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
                         std::vector<Node *>(outputs.begin(), outputs.end()));
 }

+void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
+  std::vector<Node *> op_nodes;
+  for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
+    if (node.type() == Node::Type::kValue || node.deleted()) {
+      continue;
+    }
+    op_nodes.push_back(&node);
+  }
+  size_t op_num = op_nodes.size();
+  for (size_t i = 0; i < op_num; i++) {
+    if (op_nodes[i]->type() == Node::Type::kFunction) continue;
+    std::unordered_set<std::string> follow_up_input_names;
+    for (size_t j = i + 1; j < op_num; j++) {
+      for (auto *in : op_nodes[j]->inlinks) {
+        follow_up_input_names.insert(in->name());
+      }
+    }
+    std::vector<Node *> filtered_subgraph_outlinks;
+    for (auto *out : op_nodes[i]->outlinks) {
+      if (follow_up_input_names.count(out->name())) {
+        filtered_subgraph_outlinks.push_back(out);
+      }
+    }
+    PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL);
+    op_nodes[i]->outlinks = filtered_subgraph_outlinks;
+  }
+}
+
 } // namespace analysis
 } // namespace inference
 } // namespace paddle
