Skip to content

Commit dcfbc6a

Browse files
authored
inference analyzer as bin (#12450)
1 parent 31a2c87 commit dcfbc6a

22 files changed

+309
-63
lines changed

paddle/fluid/inference/analysis/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
66
tensorrt_subgraph_node_mark_pass.cc
77
analyzer.cc
88
helper.cc
9+
model_store_pass.cc
910
DEPS framework_proto proto_desc)
1011
cc_test(test_node SRCS node_tester.cc DEPS analysis)
1112
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
13+
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
1214

1315
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
1416

@@ -40,3 +42,4 @@ inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_
4042
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
4143
inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
4244
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc)
45+
inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
1818
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
1919
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
20+
#include "paddle/fluid/inference/analysis/model_store_pass.h"
2021
#include "paddle/fluid/inference/analysis/pass_manager.h"
2122
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
2223
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
@@ -29,6 +30,9 @@ DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
2930
DEFINE_string(inference_analysis_graphviz_log_root, "./",
3031
"Graphviz debuger for data flow graphs.");
3132

33+
DEFINE_string(inference_analysis_output_storage_path, "",
34+
"optimized model output path");
35+
3236
namespace inference {
3337
namespace analysis {
3438

@@ -47,6 +51,9 @@ class DfgPassManagerImpl final : public DfgPassManager {
4751
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
4852
}
4953
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
54+
if (!FLAGS_inference_analysis_output_storage_path.empty()) {
55+
AddPass("model-store-pass", new ModelStorePass);
56+
}
5057
}
5158

5259
std::string repr() const override { return "dfg-pass-manager"; }

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,23 @@ limitations under the License. */
1616

1717
/*
1818
* This file contains Analyzer, a class exposed as a library that can analyze
19-
* and optimize
20-
* Fluid ProgramDesc for inference. Similar to LLVM, it has multiple flags to
21-
* control whether
22-
* an process is applied on the program.
19+
* and optimize Fluid ProgramDesc for inference. Similar to LLVM, it has
20+
* multiple flags to
21+
* control whether a process is applied to the program.
2322
*
2423
* The processes are called Passes in analysis, the Passes are placed in a
25-
* pipeline, the first
26-
* Pass is the FluidToDataFlowGraphPass which transforms a Fluid ProgramDesc to
27-
* a data flow
28-
* graph, the last Pass is DataFlowGraphToFluidPass which transforms a data flow
29-
* graph to a
30-
* Fluid ProgramDesc. The passes in the middle of the pipeline can be any Passes
31-
* which take a
32-
* node or data flow graph as input.
24+
* pipeline, the first Pass is the FluidToDataFlowGraphPass which transforms a
25+
* Fluid ProgramDesc to
26+
* a data flow graph, the last Pass is DataFlowGraphToFluidPass which transforms
27+
* a data flow graph to a Fluid ProgramDesc. The passes in the middle of the
28+
* pipeline can be any Passes
29+
* which take a node or data flow graph as input.
3330
*
3431
* The Analyzer can be used in two ways: the first is an executable file which
35-
* can be used to
36-
* pre-process the inference model and can be controlled by passing difference
37-
* command flags;
32+
* can be used to pre-process the inference model and can be controlled by
33+
* passing different command flags;
3834
* the other way is to compose inside the inference API as a runtime pre-process
39-
* phase in the
40-
* inference service.
35+
* phase in the inference service.
4136
*/
4237

4338
#include <gflags/gflags.h>
@@ -50,6 +45,7 @@ namespace paddle {
5045
// flag if not available.
5146
DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine);
5247
DECLARE_string(inference_analysis_graphviz_log_root);
48+
DECLARE_string(inference_analysis_output_storage_path);
5349

5450
namespace inference {
5551
namespace analysis {
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/*
16+
* This file implements the analyzer -- an executable that helps to analyze and
17+
* optimize a trained model.
18+
*/
19+
#include "paddle/fluid/inference/analysis/analyzer.h"
20+
#include <gflags/gflags.h>
21+
#include <glog/logging.h>
22+
23+
int main(int argc, char** argv) {
24+
google::ParseCommandLineFlags(&argc, &argv, true);
25+
using paddle::inference::analysis::Analyzer;
26+
using paddle::inference::analysis::Argument;
27+
28+
Argument argument;
29+
Analyzer analyzer;
30+
analyzer.Run(&argument);
31+
32+
return 0;
33+
}

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@ namespace paddle {
2020
namespace inference {
2121
namespace analysis {
2222

23-
TEST_F(DFG_Tester, analysis_without_tensorrt) {
23+
TEST(Analyzer, analysis_without_tensorrt) {
2424
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = false;
25+
Argument argument;
26+
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
2527
Analyzer analyser;
2628
analyser.Run(&argument);
2729
}
2830

29-
TEST_F(DFG_Tester, analysis_with_tensorrt) {
31+
TEST(Analyzer, analysis_with_tensorrt) {
3032
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true;
33+
Argument argument;
34+
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
3135
Analyzer analyser;
3236
analyser.Run(&argument);
3337
}

paddle/fluid/inference/analysis/argument.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ namespace analysis {
3636
* All the fields should be registered here for clearness.
3737
*/
3838
struct Argument {
39+
Argument() = default;
40+
explicit Argument(const std::string& fluid_model_dir)
41+
: fluid_model_dir(new std::string(fluid_model_dir)) {}
42+
// The directory of the trained model.
43+
std::unique_ptr<std::string> fluid_model_dir;
44+
// The path of `__model__` and `param`, this is used when the file name of
45+
// model and param is changed.
46+
std::unique_ptr<std::string> fluid_model_program_path;
47+
std::unique_ptr<std::string> fluid_model_param_path;
48+
3949
// The graph that is processed by the Passes or PassManagers.
4050
std::unique_ptr<DataFlowGraph> main_dfg;
4151

@@ -44,6 +54,9 @@ struct Argument {
4454

4555
// The processed program desc.
4656
std::unique_ptr<framework::proto::ProgramDesc> transformed_program_desc;
57+
58+
// The output storage path of ModelStorePass.
59+
std::unique_ptr<std::string> model_output_store_path;
4760
};
4861

4962
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)

paddle/fluid/inference/analysis/data_flow_graph_tester.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace inference {
2020
namespace analysis {
2121

2222
TEST(DataFlowGraph, BFS) {
23-
auto desc = LoadProgramDesc();
23+
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
2424
auto dfg = ProgramDescToDFG(desc);
2525
dfg.Build();
2626

@@ -44,7 +44,7 @@ TEST(DataFlowGraph, BFS) {
4444
}
4545

4646
TEST(DataFlowGraph, DFS) {
47-
auto desc = LoadProgramDesc();
47+
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
4848
auto dfg = ProgramDescToDFG(desc);
4949
dfg.Build();
5050
GraphTraits<DataFlowGraph> trait(&dfg);

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,21 @@ namespace paddle {
2626
namespace inference {
2727
namespace analysis {
2828

29-
TEST_F(DFG_Tester, Test) {
30-
DataFlowGraph graph;
29+
TEST(DataFlowGraph, Test) {
30+
Argument argument(FLAGS_inference_model_dir);
3131

3232
FluidToDataFlowGraphPass pass0;
3333
DataFlowGraphToFluidPass pass1;
3434
ASSERT_TRUE(pass0.Initialize(&argument));
3535
ASSERT_TRUE(pass1.Initialize(&argument));
3636

37-
pass0.Run(&graph);
38-
pass1.Run(&graph);
37+
pass0.Run(argument.main_dfg.get());
38+
pass1.Run(argument.main_dfg.get());
3939

4040
pass0.Finalize();
4141
pass1.Finalize();
4242

43-
LOG(INFO) << graph.nodes.size();
43+
LOG(INFO) << argument.main_dfg->nodes.size();
4444
}
4545

4646
}; // namespace analysis

paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,18 @@ namespace paddle {
2323
namespace inference {
2424
namespace analysis {
2525

26-
TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
27-
auto dfg = ProgramDescToDFG(*argument.origin_program_desc);
26+
TEST(DFG_GraphvizDrawPass, dfg_graphviz_draw_pass_tester) {
27+
Argument argument(FLAGS_inference_model_dir);
28+
FluidToDataFlowGraphPass pass0;
29+
ASSERT_TRUE(pass0.Initialize(&argument));
30+
pass0.Run(argument.main_dfg.get());
31+
32+
// auto dfg = ProgramDescToDFG(*argument.origin_program_desc);
33+
2834
DFG_GraphvizDrawPass::Config config("./", "test");
2935
DFG_GraphvizDrawPass pass(config);
3036
pass.Initialize(&argument);
31-
pass.Run(&dfg);
37+
pass.Run(argument.main_dfg.get());
3238

3339
// test content
3440
std::ifstream file("./0-graph_test.dot");

paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include <glog/logging.h>
1516
#include <string>
1617
#include <vector>
1718

@@ -25,8 +26,20 @@ namespace analysis {
2526

2627
bool FluidToDataFlowGraphPass::Initialize(Argument *argument) {
2728
ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
28-
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc);
29-
PADDLE_ENFORCE(argument);
29+
if (argument->origin_program_desc) {
30+
LOG(WARNING) << "argument's origin_program_desc is already set, might "
31+
"duplicate called";
32+
}
33+
if (!argument->fluid_model_program_path) {
34+
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_dir);
35+
argument->fluid_model_program_path.reset(
36+
new std::string(*argument->fluid_model_dir + "/__model__"));
37+
}
38+
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
39+
auto program = LoadProgramDesc(*argument->fluid_model_program_path);
40+
argument->origin_program_desc.reset(
41+
new framework::proto::ProgramDesc(program));
42+
3043
if (!argument->main_dfg) {
3144
argument->main_dfg.reset(new DataFlowGraph);
3245
}

0 commit comments

Comments
 (0)