Skip to content

Commit e710d2c

Browse files
author
Yibing Liu
committed
Merge branch 'develop' of upstream into argsort_dev
2 parents a523b6f + d734595 commit e710d2c

26 files changed

+830
-108
lines changed
Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,32 @@
11
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init)
2-
cc_library(analysis SRCS dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc
3-
DEPS paddle_fluid)
2+
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
3+
fluid_to_data_flow_graph_pass.cc
4+
data_flow_graph_to_fluid_pass.cc
5+
tensorrt_subgraph_pass.cc
6+
dfg_graphviz_draw_pass.cc
7+
DEPS framework_proto)
48
cc_test(test_node SRCS node_tester.cc DEPS analysis)
59
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
610

711
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
812

9-
cc_test(test_data_flow_graph SRCS data_flow_graph_tester.cc DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid
10-
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
11-
set_tests_properties(test_data_flow_graph PROPERTIES DEPENDS test_word2vec)
13+
function (inference_analysis_test TARGET)
14+
set(options "")
15+
set(oneValueArgs "")
16+
set(multiValueArgs SRCS)
17+
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
1218

13-
cc_test(test_subgraph_splitter
14-
SRCS subgraph_splitter_tester.cc
15-
DEPS analysis paddle_fluid tensor
16-
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
17-
set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec)
19+
cc_test(${TARGET}
20+
SRCS "${analysis_test_SRCS}"
21+
DEPS analysis
22+
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5)
23+
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
24+
endfunction(inference_analysis_test)
1825

19-
cc_test(test_dfg_graphviz_draw_pass
20-
SRCS dfg_graphviz_draw_pass_tester.cc
21-
DEPS analysis
22-
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
23-
set_tests_properties(test_dfg_graphviz_draw_pass PROPERTIES DEPENDS test_word2vec)
26+
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
27+
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
28+
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
29+
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
30+
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
31+
#inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
32+
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/inference/analysis/argument.h"
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/*
16+
* This file defines the class Argument, which is the input and output of the
17+
* analysis module. All the fields that needed either by Passes or PassManagers
18+
* are contained in Argument.
19+
*
20+
* TODO(Superjomn) Find some way better to contain the fields when it grow too
21+
* big.
22+
*/
23+
24+
#include "paddle/fluid/framework/program_desc.h"
25+
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
26+
27+
namespace paddle {
28+
namespace inference {
29+
namespace analysis {
30+
31+
/*
32+
* The argument definition of both Pass and PassManagers.
33+
*
34+
* All the fields should be registered here for clearness.
35+
*/
36+
struct Argument {
37+
// The graph that process by the Passes or PassManagers.
38+
std::unique_ptr<DataFlowGraph> main_dfg;
39+
40+
// The original program desc.
41+
std::unique_ptr<framework::proto::ProgramDesc> origin_program_desc;
42+
};
43+
44+
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
45+
#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \
46+
if (!UNLIKELY(field__)) { \
47+
LOG(ERROR) << "field " << #field__ << " should be set."; \
48+
return false; \
49+
}
50+
51+
} // namespace analysis
52+
} // namespace inference
53+
} // namespace paddle

paddle/fluid/inference/analysis/data_flow_graph.cc

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
1616
#include "paddle/fluid/inference/analysis/dot.h"
17+
#include "paddle/fluid/inference/analysis/node.h"
1718

1819
namespace paddle {
1920
namespace inference {
@@ -57,19 +58,7 @@ std::string DataFlowGraph::DotString() const {
5758
// Add nodes
5859
for (size_t i = 0; i < nodes.size(); i++) {
5960
const Node &node = nodes.Get(i);
60-
switch (node.type()) {
61-
case Node::Type::kValue:
62-
dot.AddNode(node.repr(), node.dot_attrs());
63-
break;
64-
case Node::Type::kFunction:
65-
dot.AddNode(node.repr(), node.dot_attrs());
66-
break;
67-
case Node::Type::kFunctionBlock:
68-
dot.AddNode(node.repr(), node.dot_attrs());
69-
break;
70-
default:
71-
PADDLE_THROW("unsupported Node type %d", static_cast<int>(node.type()));
72-
}
61+
dot.AddNode(node.repr(), node.dot_attrs());
7362
}
7463

7564
// Add edges
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
16+
#include "paddle/fluid/framework/proto_desc.h"
17+
18+
namespace paddle {
19+
namespace inference {
20+
namespace analysis {
21+
22+
bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
23+
ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
24+
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
25+
desc_ = argument->origin_program_desc.get();
26+
// Here some logic from program_desc.cc and will not add new interfaces into
27+
// framework::ProgramDesc class, use some UT to assure the correctness.
28+
auto* block = desc_->mutable_blocks()->Add();
29+
block->set_idx(framework::kRootBlockIndex);
30+
block->set_parent_idx(framework::kNoneBlockIndex);
31+
return true;
32+
}
33+
34+
bool DataFlowGraphToFluidPass::Finalize() { return true; }
35+
36+
void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) {
37+
auto traits = GraphTraits<DataFlowGraph>(graph);
38+
for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) {
39+
if (it->deleted()) continue;
40+
switch (it->type()) {
41+
case Node::Type::kFunction:
42+
LOG(INFO) << "add function " << it->name();
43+
AddFluidOp(&(*it));
44+
break;
45+
case Node::Type::kFunctionBlock:
46+
AddEngineOp(&(*it));
47+
break;
48+
default:
49+
continue;
50+
}
51+
}
52+
}
53+
54+
void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
55+
LOG(INFO) << "processing func " << node->name();
56+
auto* ori_op = static_cast<framework::proto::OpDesc*>(node->pb_desc());
57+
// currently only the main block is analyzed.
58+
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
59+
auto* op = main_block->add_ops();
60+
LOG(INFO) << "to copy the op";
61+
*op = *ori_op; // copy the attributes, by default, these will not be changed
62+
// by analysis phrase.
63+
// The inputs and outputs of the existing ops are not changed by tensorrt
64+
// subgraph pass.
65+
// NOTE It might be changed by other passes in the long run.
66+
}
67+
68+
void DataFlowGraphToFluidPass::AddEngineOp(Node* node) {
69+
// auto* ori_op = static_cast<framework::proto::OpDesc*>(node->extra_info());
70+
// auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
71+
// auto* op = main_block->add_ops();
72+
// TODO(Superjomn) Here need to expose some arguments for default setting.
73+
}
74+
75+
} // namespace analysis
76+
} // namespace inference
77+
} // namespace paddle
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
/*
16+
* This file implements the transformation from fluid ProgramDesc to data flow
17+
* graph.
18+
*/
19+
20+
#pragma once
21+
22+
#include "paddle/fluid/framework/program_desc.h"
23+
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
24+
#include "paddle/fluid/inference/analysis/pass.h"
25+
26+
namespace paddle {
27+
namespace inference {
28+
namespace analysis {
29+
class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
30+
public:
31+
DataFlowGraphToFluidPass() = default;
32+
33+
bool Initialize(Argument *argument) override;
34+
bool Finalize() override;
35+
36+
void Run(DataFlowGraph *graph) override;
37+
38+
std::string repr() const override { return "DFG to fluid"; }
39+
std::string description() const override {
40+
return "Transform a DFG to a Fluid ProgramDesc";
41+
}
42+
43+
Pass *CreatePrinterPass(std::ostream &os,
44+
const std::string &banner) const override {
45+
return nullptr;
46+
}
47+
48+
protected:
49+
// Add a Fluid Op into the ProgramDesc.
50+
void AddFluidOp(Node *node);
51+
// Add a EngineOp into the ProgramDesc.
52+
void AddEngineOp(Node *node);
53+
54+
private:
55+
framework::proto::ProgramDesc *desc_;
56+
};
57+
} // namespace analysis
58+
} // namespace inference
59+
} // namespace paddle

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,12 @@ namespace inference {
2727
namespace analysis {
2828

2929
TEST_F(DFG_Tester, Test) {
30-
framework::proto::ProgramDesc new_desc;
3130
DataFlowGraph graph;
3231

3332
FluidToDataFlowGraphPass pass0;
3433
DataFlowGraphToFluidPass pass1;
35-
pass0.Initialize(desc);
36-
pass1.Initialize(&new_desc);
34+
ASSERT_TRUE(pass0.Initialize(&argument));
35+
ASSERT_TRUE(pass1.Initialize(&argument));
3736

3837
pass0.Run(&graph);
3938
pass1.Run(&graph);
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
16+
17+
namespace paddle {
18+
namespace inference {
19+
namespace analysis {
20+
21+
void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) {
22+
auto content = Draw(graph);
23+
std::ofstream file(GenDotPath());
24+
file.write(content.c_str(), content.size());
25+
file.close();
26+
LOG(INFO) << "draw dot to " << GenDotPath();
27+
}
28+
29+
std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
30+
Dot dot;
31+
// Add nodes
32+
for (size_t i = 0; i < graph->nodes.size(); i++) {
33+
const Node &node = graph->nodes.Get(i);
34+
if (config_.display_deleted_node || !node.deleted()) {
35+
dot.AddNode(node.repr(), node.dot_attrs());
36+
}
37+
}
38+
// Add edges
39+
for (size_t i = 0; i < graph->nodes.size(); i++) {
40+
const Node &node = graph->nodes.Get(i);
41+
if (!config_.display_deleted_node && node.deleted()) continue;
42+
for (auto &in : node.inlinks) {
43+
if (!config_.display_deleted_node && in->deleted()) continue;
44+
for (auto &in : node.inlinks) {
45+
dot.AddEdge(in->repr(), node.repr(), {});
46+
}
47+
}
48+
}
49+
return dot.Build();
50+
}
51+
52+
} // namespace analysis
53+
} // namespace inference
54+
} // namespace paddle

0 commit comments

Comments
 (0)