Skip to content

Commit 59db5ef

Browse files
committed
add TensorRT unit tests and refine the interface.
test=release/1.0.0
1 parent 644bad1 commit 59db5ef

15 files changed

+180
-27
lines changed

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,16 @@ TEST(Analyzer, analysis_without_tensorrt) {
3737
TEST(Analyzer, analysis_with_tensorrt) {
3838
FLAGS_IA_enable_tensorrt_subgraph_engine = true;
3939
Argument argument;
40+
argument.Set<int>("minimum_subgraph_size", new int(0));
41+
argument.Set<int>("max_batch_size", new int(3));
42+
argument.Set<int>("workspace_size", new int(1 << 20));
43+
argument.Set<std::string>("precision_mode", new std::string("FP32"));
4044
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
4145
Analyzer analyser;
4246
analyser.Run(&argument);
4347
}
4448

45-
void TestWord2vecPrediction(const std::string &model_path) {
49+
void TestWord2vecPrediction(const std::string& model_path) {
4650
NativeConfig config;
4751
config.model_dir = model_path;
4852
config.use_gpu = false;
@@ -73,8 +77,8 @@ void TestWord2vecPrediction(const std::string &model_path) {
7377
// The outputs' buffers are in CPU memory.
7478
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
7579
LOG(INFO) << "data: "
76-
<< static_cast<float *>(outputs.front().data.data())[i];
77-
PADDLE_ENFORCE(static_cast<float *>(outputs.front().data.data())[i],
80+
<< static_cast<float*>(outputs.front().data.data())[i];
81+
PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
7882
result[i]);
7983
}
8084
}

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,10 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
9797
}
9898
}
9999

100-
void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
100+
void CreateTrtEngineOp(Node *node, Argument *argument,
101101
framework::proto::BlockDesc *block) {
102+
PADDLE_ENFORCE(argument->main_dfg.get());
103+
const DataFlowGraph &graph = *(argument->main_dfg);
102104
static int counter{0};
103105
PADDLE_ENFORCE(node->IsFunctionBlock());
104106
framework::OpDesc desc;
@@ -204,7 +206,10 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
204206

205207
PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc");
206208
// Set attrs
209+
207210
SetAttr(desc.Proto(), "subgraph", block->SerializeAsString());
211+
SetAttr(desc.Proto(), "max_batch_size", argument->Get<int>("max_batch_size"));
212+
SetAttr(desc.Proto(), "workspace_size", argument->Get<int>("workspace_size"));
208213
SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++));
209214
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
210215
SetAttr(desc.Proto(), "output_name_mapping", output_mapping);
@@ -248,7 +253,7 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
248253
*block_desc.Proto()->mutable_vars() =
249254
argument_->origin_program_desc->blocks(0).vars();
250255
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
251-
CreateTrtEngineOp(node, *argument_->main_dfg, block_desc.Proto());
256+
CreateTrtEngineOp(node, argument_, block_desc.Proto());
252257
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
253258
auto *op = main_block->add_ops();
254259
PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");

paddle/fluid/inference/analysis/subgraph_splitter.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,8 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); }
309309
void SubGraphFuse::ReplaceNodesWithSubGraphs() {
310310
auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
311311
for (auto &subgraph : subgraphs) {
312+
if (subgraph.size() <= argument_->Get<int>("minimum_subgraph_size"))
313+
continue;
312314
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
313315
// replace this sub-graph with the first node. Two steps: 1. Create a Block
314316
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph

paddle/fluid/inference/analysis/subgraph_splitter.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License. */
2020

2121
#include <vector>
2222

23+
#include "paddle/fluid/inference/analysis/argument.h"
2324
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
2425
#include "paddle/fluid/inference/analysis/node.h"
2526

@@ -63,8 +64,11 @@ class SubGraphFuse {
6364
public:
6465
using NodeInsideSubgraphTeller = SubGraphSplitter::NodeInsideSubgraphTeller;
6566

66-
SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller)
67-
: graph_(graph), node_inside_subgraph_teller_(teller) {}
67+
SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller,
68+
Argument *argument)
69+
: graph_(graph),
70+
node_inside_subgraph_teller_(teller),
71+
argument_(argument) {}
6872

6973
// The main method which run all the logic.
7074
void operator()();
@@ -76,6 +80,7 @@ class SubGraphFuse {
7680
private:
7781
DataFlowGraph *graph_;
7882
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
83+
Argument *argument_;
7984
};
8085

8186
} // namespace analysis

paddle/fluid/inference/analysis/subgraph_splitter_tester.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@ TEST(SubGraphSplitter, Split) {
6666
TEST(SubGraphSplitter, Fuse) {
6767
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
6868
auto dfg = ProgramDescToDFG(desc);
69+
Argument argument;
70+
argument.Set<int>("minimum_subgraph_size", new int(3));
6971

7072
size_t count0 = dfg.nodes.size();
7173

72-
SubGraphFuse fuse(&dfg, teller);
74+
SubGraphFuse fuse(&dfg, teller, &argument);
7375
fuse();
7476

7577
int count1 = 0;

paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ TensorRTSubGraphPass::TensorRTSubGraphPass(
2424
: node_inside_subgraph_teller_(teller) {}
2525

2626
void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
27-
SubGraphFuse(graph, node_inside_subgraph_teller_)();
27+
SubGraphFuse(graph, node_inside_subgraph_teller_, argument_)();
2828
VLOG(4) << "debug info "
2929
<< graph->HumanReadableInfo(false /*show_values*/,
3030
true /*show_functions*/);

paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ class TensorRTSubGraphPass : public DataFlowGraphPass {
3333

3434
explicit TensorRTSubGraphPass(const NodeInsideSubgraphTeller& teller);
3535

36-
bool Initialize(Argument* argument) override { return true; }
36+
bool Initialize(Argument* argument) override {
37+
argument_ = argument;
38+
return true;
39+
}
3740

3841
// This class get a sub-graph as input and determine whether to transform this
3942
// sub-graph into TensorRT.
@@ -46,6 +49,7 @@ class TensorRTSubGraphPass : public DataFlowGraphPass {
4649

4750
private:
4851
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
52+
Argument* argument_;
4953
};
5054

5155
} // namespace analysis

paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ TEST(TensorRTSubGraphPass, main) {
3636
};
3737

3838
Argument argument(FLAGS_inference_model_dir);
39+
argument.Set<int>("minimum_subgraph_size", new int(0));
40+
argument.Set<int>("max_batch_size", new int(3));
41+
argument.Set<int>("workspace_size", new int(1 << 20));
42+
argument.Set<std::string>("precision_mode", new std::string("FP32"));
3943

4044
DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
4145
DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};

paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
3535
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
3636
FLAGS_IA_enable_tensorrt_subgraph_engine = true;
3737
VLOG(3) << "Predictor::init()";
38-
FLAGS_tensorrt_max_batch_size = config_.max_batch_size;
39-
FLAGS_tensorrt_workspace_size = config_.workspace_size;
4038
if (config_.use_gpu) {
4139
place_ = paddle::platform::CUDAPlace(config_.device);
4240
} else {
@@ -92,6 +90,14 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
9290
void OptimizeInferenceProgram() {
9391
// Analyze inference_program
9492
Argument argument;
93+
94+
argument.Set<int>("minimum_subgraph_size",
95+
new int(config_.minimum_subgraph_size));
96+
argument.Set<int>("max_batch_size", new int(config_.max_batch_size));
97+
argument.Set<int>("workspace_size", new int(config_.workspace_size));
98+
argument.Set<std::string>("precision_mode",
99+
new std::string(config_.precision_mode));
100+
95101
if (!config_.model_dir.empty()) {
96102
argument.fluid_model_dir.reset(new std::string(config_.model_dir));
97103
} else {

paddle/fluid/inference/api/paddle_inference_api.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,14 @@ struct MixedRTConfig : public NativeConfig {
194194
// For workspace_size, refer it from here:
195195
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting
196196
int workspace_size{1 << 30};
197+
// We transform the Ops that can be converted into TRT layers in the model,
198+
// and aggregate these Ops into subgraphs for TRT execution.
199+
// We set this variable to control the minimum number of nodes in the
200+
// subgraph; the default value is 3.
201+
int minimum_subgraph_size = 3;
202+
// Reserved configuration
203+
// We only support "FP32" now; "FP16" and "INT8" will be supported later.
204+
std::string precision_mode = "FP32";
197205
};
198206

199207
// NOTE WIP, not stable yet.

0 commit comments

Comments (0)