
Commit 5082642

feature/analysis to support sub-graph for TRT engine (#11538)

1 parent bc28cf6

39 files changed: +1015 −189 lines

paddle/contrib/inference/CMakeLists.txt

Lines changed: 9 additions & 1 deletion

@@ -18,7 +18,7 @@ if(APPLE)
 endif(APPLE)
 
 
-set(inference_deps paddle_inference_api paddle_fluid_api)
+set(inference_deps paddle_inference_api paddle_fluid_api paddle_inference_tensorrt_subgraph_engine)
 
 function(inference_api_test TARGET_NAME)
   if (WITH_TESTING)
@@ -50,6 +50,14 @@ cc_test(test_paddle_inference_api
 inference_api_test(test_paddle_inference_api_impl
     ARGS test_word2vec test_image_classification)
 
+if(WITH_GPU AND TENSORRT_FOUND)
+  cc_library(paddle_inference_tensorrt_subgraph_engine
+    SRCS paddle_inference_api_tensorrt_subgraph_engine.cc
+    DEPS paddle_inference_api analysis tensorrt_engine paddle_inference_api paddle_fluid_api)
+
+  inference_api_test(test_paddle_inference_api_tensorrt_subgraph_engine ARGS test_word2vec)
+endif()
+
 if (WITH_ANAKIN AND WITH_TESTING) # only needed in CI
   # Because Anakin has no official library releases and its protobuf and cuda versions do not match Paddle's,
   # the anakin library will not be merged into our official inference library. To use the anakin prediction API, one needs to

paddle/contrib/inference/paddle_inference_api.h

Lines changed: 8 additions & 3 deletions

@@ -73,12 +73,12 @@ struct PaddleTensor {
 };
 
 enum class PaddleEngineKind {
-  kNative = 0,  // Use the native Fluid facility.
-  kAnakin,      // Use Anakin for inference.
+  kNative = 0,         // Use the native Fluid facility.
+  kAnakin,             // Use Anakin for inference.
+  kAutoMixedTensorRT,  // Automatically mix Fluid with TensorRT.
   // TODO(Superjomn) support following engines later.
   // kTensorRT,         // Use TensorRT for inference.
   // kAutoMixedAnakin,  // Automatically mix Fluid with Anakin.
-  // kAutoMixedTensorRT,  // Automatically mix Fluid with TensorRT.
 };
 
 /*
@@ -130,6 +130,11 @@ struct AnakinConfig : public PaddlePredictor::Config {
   int max_batch_size{-1};
 };
 
+struct TensorRTConfig : public NativeConfig {
+  // Determine whether a subgraph will be executed by TRT.
+  int min_subgraph_size{1};
+};
+
 // A factory to help create different predictors.
 //
 // FOR EXTENSION DEVELOPER:
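With this, a caller opts into the mixed engine by filling in a TensorRTConfig — which inherits the model, device, and GPU-memory fields from NativeConfig — and asking the factory for PaddleEngineKind::kAutoMixedTensorRT. A minimal sketch of that call sequence (the model path is illustrative; everything else mirrors the new test further down):

#include <memory>

#include "paddle/contrib/inference/paddle_inference_api.h"

std::unique_ptr<paddle::PaddlePredictor> MakeTrtPredictor() {
  paddle::TensorRTConfig config;
  config.model_dir = "/path/to/word2vec.inference.model";  // hypothetical path
  config.use_gpu = true;
  config.fraction_of_gpu_memory = 0.15;  // must lie in (0., 1.]
  config.device = 0;
  config.min_subgraph_size = 1;

  // The factory specialization added in this commit constructs and
  // initializes a TensorRTSubgraphPredictor; it returns nullptr on failure.
  return paddle::CreatePaddlePredictor<
      paddle::TensorRTConfig, paddle::PaddleEngineKind::kAutoMixedTensorRT>(
      config);
}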

paddle/contrib/inference/paddle_inference_api_impl.cc

Lines changed: 5 additions & 1 deletion

@@ -89,6 +89,7 @@ bool NativePaddlePredictor::Init(
     LOG(ERROR) << "fail to load inference model.";
     return false;
   }
+
   ctx_ = executor_->Prepare(*inference_program_, 0);
   executor_->CreateVariables(
       *inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0);
@@ -119,6 +120,7 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
     return false;
   }
   for (size_t i = 0; i < feed_target_names_.size(); ++i) {
+    VLOG(4) << "setting " << i << "-th target";
     feed_targets[feed_target_names_[i]] = &feeds[i];
   }
   // get fetch variable
@@ -130,14 +132,16 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
   }
   // Run the inference program
   // if variables are shared, we need not create them
+  VLOG(4) << "Run prepared context";
   executor_->RunPreparedContext(
       ctx_.get(),
       sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
       &feed_targets,
       &fetch_targets,
       false /* don't create variable each time */);
+  VLOG(4) << "Finish prepared context";
   if (!GetFetch(fetchs, output_data)) {
-    LOG(ERROR) << "fail to get fetchs";
+    LOG(ERROR) << "fail to get fetches";
     return false;
   }
   VLOG(3) << "predict cost: " << timer.toc() << "ms";

paddle/contrib/inference/paddle_inference_api_impl.h

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ class NativePaddlePredictor : public PaddlePredictor {
 
   ~NativePaddlePredictor() override;
 
- private:
+ protected:
   bool SetFeed(const std::vector<PaddleTensor> &input_datas,
                std::vector<framework::LoDTensor> *feeds);
   bool GetFetch(const std::vector<framework::LoDTensor> &fetchs,
paddle/contrib/inference/paddle_inference_api_tensorrt_subgraph_engine.cc

Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/contrib/inference/paddle_inference_api.h"
+#include "paddle/contrib/inference/paddle_inference_api_impl.h"
+#include "paddle/fluid/inference/analysis/analyzer.h"
+#include "paddle/fluid/inference/utils/singleton.h"
+
+namespace paddle {
+
+using inference::analysis::Argument;
+using inference::Singleton;
+using inference::analysis::Analyzer;
+using framework::proto::ProgramDesc;
+
+class TensorRTSubgraphPredictor : public NativePaddlePredictor {
+ public:
+  explicit TensorRTSubgraphPredictor(const TensorRTConfig& config)
+      : NativePaddlePredictor(config), config_(config) {}
+
+  bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
+    VLOG(3) << "Predictor::init()";
+
+    if (config_.use_gpu) {
+      place_ = paddle::platform::CUDAPlace(config_.device);
+    } else {
+      place_ = paddle::platform::CPUPlace();
+    }
+    if (parent_scope) {
+      scope_ = parent_scope;
+      sub_scope_ = &(parent_scope->NewScope());
+    } else {
+      paddle::framework::InitDevices(false);
+      scope_.reset(new paddle::framework::Scope());
+    }
+
+    executor_.reset(new paddle::framework::Executor(place_));
+
+    // Initialize the inference program
+    if (!config_.model_dir.empty()) {
+      // Parameters are saved in separate files sited in
+      // the specified `dirname`.
+      inference_program_ = paddle::inference::Load(
+          executor_.get(), scope_.get(), config_.model_dir);
+    } else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
+      // All parameters are saved in a single file.
+      // The file names should be consistent with that used
+      // in Python API `fluid.io.save_inference_model`.
+      inference_program_ = paddle::inference::Load(
+          executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
+    } else {
+      LOG(ERROR) << "fail to load inference model.";
+      return false;
+    }
+
+    // Analyze inference_program
+    Argument argument;
+    argument.origin_program_desc.reset(
+        new ProgramDesc(*inference_program_->Proto()));
+    Singleton<Analyzer>::Global().Run(&argument);
+    CHECK(argument.transformed_program_desc);
+    VLOG(5) << "transformed program:\n"
+            << argument.transformed_program_desc->SerializeAsString();
+    VLOG(5) << "to prepare executor";
+    *inference_program_->Proto() = *argument.transformed_program_desc;
+    ctx_ = executor_->Prepare(*inference_program_, 0);
+
+    VLOG(5) << "to create variables";
+    executor_->CreateVariables(
+        *inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0);
+
+    // Get the feed_target_names and fetch_target_names
+    feed_target_names_ = inference_program_->GetFeedTargetNames();
+    fetch_target_names_ = inference_program_->GetFetchTargetNames();
+    return true;
+  }
+
+ private:
+  TensorRTConfig config_;
+};
+
+template <>
+std::unique_ptr<PaddlePredictor>
+CreatePaddlePredictor<TensorRTConfig, PaddleEngineKind::kAutoMixedTensorRT>(
+    const TensorRTConfig& config) {
+  VLOG(3) << "create TensorRTSubgraphPredictor";
+  if (config.use_gpu) {
+    // 1. GPU memory
+    PADDLE_ENFORCE_GT(
+        config.fraction_of_gpu_memory,
+        0.f,
+        "fraction_of_gpu_memory in the config should be set to range (0., 1.]");
+    PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);
+    std::vector<std::string> flags;
+    if (config.fraction_of_gpu_memory >= 0.0f &&
+        config.fraction_of_gpu_memory <= 0.95f) {
+      flags.push_back("dummy");
+      std::string flag = "--fraction_of_gpu_memory_to_use=" +
+                         std::to_string(config.fraction_of_gpu_memory);
+      flags.push_back(flag);
+      VLOG(3) << "set flag: " << flag;
+      framework::InitGflags(flags);
+    }
+  }
+
+  std::unique_ptr<PaddlePredictor> predictor(
+      new TensorRTSubgraphPredictor(config));
+  if (!dynamic_cast<TensorRTSubgraphPredictor*>(predictor.get())
+           ->Init(nullptr)) {
+    return nullptr;
+  }
+  return std::move(predictor);
+}
+
+}  // namespace paddle
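The Init() above is mostly the native load path; the new piece is the analysis handoff in the middle, where the predictor serializes its program, lets the Analyzer rewrite it, and swaps the transformed desc back in. Distilled into a standalone helper, that step looks roughly like this (a sketch only; it assumes a ProgramDesc is already loaded, and uses the Argument fields exactly as above):

#include "paddle/fluid/framework/framework.pb.h"  // generated from framework.proto
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/utils/singleton.h"

namespace paddle {

// Rewrite `desc` in place, fusing TRT-capable subgraphs when the analysis
// flag (see analyzer.cc below) is enabled.
bool TransformProgram(framework::proto::ProgramDesc* desc) {
  using inference::Singleton;
  using inference::analysis::Analyzer;
  using inference::analysis::Argument;

  Argument argument;
  // The analyzer works on a copy of the original program...
  argument.origin_program_desc.reset(new framework::proto::ProgramDesc(*desc));
  Singleton<Analyzer>::Global().Run(&argument);
  // ...and hands back a transformed program for the executor to run.
  if (!argument.transformed_program_desc) return false;
  *desc = *argument.transformed_program_desc;
  return true;
}

}  // namespace paddle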
paddle/contrib/inference/test_paddle_inference_api_tensorrt_subgraph_engine.cc

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include "paddle/contrib/inference/paddle_inference_api.h"
+
+namespace paddle {
+
+DEFINE_string(dirname, "", "Directory of the inference model.");
+
+void Main(bool use_gpu) {
+  //# 1. Create PaddlePredictor with a config.
+  TensorRTConfig config;
+  config.model_dir = FLAGS_dirname + "word2vec.inference.model";
+  config.use_gpu = use_gpu;
+  config.fraction_of_gpu_memory = 0.15;
+  config.device = 0;
+  auto predictor =
+      CreatePaddlePredictor<TensorRTConfig,
+                            PaddleEngineKind::kAutoMixedTensorRT>(config);
+
+  for (int batch_id = 0; batch_id < 3; batch_id++) {
+    //# 2. Prepare input.
+    int64_t data[4] = {1, 2, 3, 4};
+
+    PaddleTensor tensor{.name = "",
+                        .shape = std::vector<int>({4, 1}),
+                        .data = PaddleBuf(data, sizeof(data)),
+                        .dtype = PaddleDType::INT64};
+
+    // For simplicity, we set all the slots with the same data.
+    std::vector<PaddleTensor> slots(4, tensor);
+
+    //# 3. Run
+    std::vector<PaddleTensor> outputs;
+    CHECK(predictor->Run(slots, &outputs));
+
+    //# 4. Get output.
+    ASSERT_EQ(outputs.size(), 1UL);
+    LOG(INFO) << "output buffer size: " << outputs.front().data.length();
+    const size_t num_elements = outputs.front().data.length() / sizeof(float);
+    // The outputs' buffers are in CPU memory.
+    for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
+      LOG(INFO) << static_cast<float*>(outputs.front().data.data())[i];
+    }
+  }
+}
+
+TEST(paddle_inference_api_tensorrt_subgraph_engine, main) { Main(true); }
+
+}  // namespace paddle
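Note that Main() concatenates FLAGS_dirname directly with "word2vec.inference.model", so the flag value must end with a slash. Assuming the test binary parses gflags in the usual way, an invocation might look like (path hypothetical):

./test_paddle_inference_api_tensorrt_subgraph_engine --dirname=/path/to/models/

Also, the analysis flag that gates the TRT passes (see analyzer.cc below) defaults to false, so running the test without it analyzes the program but fuses no TRT subgraphs.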

paddle/fluid/inference/analysis/CMakeLists.txt

Lines changed: 8 additions & 4 deletions

@@ -1,10 +1,12 @@
-set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init)
 cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
   fluid_to_data_flow_graph_pass.cc
   data_flow_graph_to_fluid_pass.cc
-  tensorrt_subgraph_pass.cc
   dfg_graphviz_draw_pass.cc
-  DEPS framework_proto)
+  tensorrt_subgraph_pass.cc
+  tensorrt_subgraph_node_mark_pass.cc
+  analyzer.cc
+  helper.cc
+  DEPS framework_proto proto_desc)
 cc_test(test_node SRCS node_tester.cc DEPS analysis)
 cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
 

@@ -28,5 +30,7 @@ inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_
 inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
 inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
 inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
-#inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
+inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
 inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
+inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
+inference_analysis_test(test_analyzer SRCS analyzer_tester.cc)
paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/analysis/analyzer.h"
+#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
+#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
+#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
+#include "paddle/fluid/inference/analysis/pass_manager.h"
+#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
+#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
+            "Enable subgraph to TensorRT engine for acceleration");
+
+DEFINE_string(inference_analysis_graphviz_log_root, "./",
+              "Graphviz debugger for data flow graphs.");
+
+class DfgPassManagerImpl final : public DfgPassManager {
+ public:
+  DfgPassManagerImpl() {
+    // TODO(Superjomn) set the key with pass reprs.
+    AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
+    if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) {
+      auto trt_teller = [](const Node* node) {
+        if (!node->IsFunction()) return false;
+        return static_cast<const Function*>(node)->func_type() == "mul";
+      };
+      AddPass("tensorrt-subgraph-marker",
+              new TensorRTSubgraphNodeMarkPass(trt_teller));
+      AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
+    }
+    AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
+  }
+
+  std::string repr() const override { return "dfg-pass-manager"; }
+  std::string description() const override { return "DFG pass manager."; }
+
+ private:
+  void AddPass(const std::string& name, Pass* pass) {
+    LOG(INFO) << "Adding pass " << name;
+    Register(name, pass);
+    AddGraphvizDebugerPass(pass);
+  }
+
+  // Add the graphviz debugger pass if the parent pass has one.
+  void AddGraphvizDebugerPass(Pass* pass) {
+    auto* debuger_pass = pass->CreateGraphvizDebugerPass();
+    if (debuger_pass) {
+      LOG(INFO) << " - register debug pass [" << debuger_pass->repr() << "]";
+      Register(debuger_pass->repr(), debuger_pass);
+    }
+  }
+};
+
+Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
+
+void Analyzer::Run(Argument* argument) {
+  for (auto& x : data_) {
+    PADDLE_ENFORCE(x->Initialize(argument));
+    x->RunAll();
+    PADDLE_ENFORCE(x->Finalize());
+  }
+}
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
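Since the TRT passes are gated on a gflag that defaults to false, embedders have to flip it before the Analyzer singleton is first used (i.e., before the first predictor Init). A hedged sketch of doing so programmatically — the init.h header location is an assumption; the InitGflags call itself mirrors the predictor factory above:

#include <string>
#include <vector>

#include "paddle/fluid/framework/init.h"  // assumed location of InitGflags

namespace paddle {

// Flip the gflag before the first predictor is created, so that
// DfgPassManagerImpl registers the tensorrt-subgraph-marker and
// tensorrt-subgraph passes on construction.
void EnableTensorRTSubgraphEngine() {
  std::vector<std::string> flags;
  flags.push_back("dummy");  // stands in for argv[0], as in the factory above
  flags.push_back("--inference_analysis_enable_tensorrt_subgraph_engine=true");
  framework::InitGflags(flags);
}

}  // namespace paddle

The same effect is available from the command line of any gflags-parsing binary by passing --inference_analysis_enable_tensorrt_subgraph_engine=true.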
