Skip to content

Commit 9df2d8b

Browse files
authored
test/add text-classification test (#13081)
1 parent 4907d09 commit 9df2d8b

17 files changed

+276
-61
lines changed

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,24 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
8686
}
8787
op_desc.SetInput("Bias", {new_bias_var});
8888
}
89-
9089
#undef GET_NODE
9190

91+
// Create temp variables.
92+
scope->Var(name_scope + "/BatchedInput.new")
93+
->GetMutable<framework::LoDTensor>();
94+
scope->Var(name_scope + "/BatchCellPreAct.new")
95+
->GetMutable<framework::LoDTensor>();
96+
scope->Var(name_scope + "/BatchedGate.new")
97+
->GetMutable<framework::LoDTensor>();
98+
9299
op_desc.SetInput("H0", {});
93100
op_desc.SetInput("C0", {});
94101
op_desc.SetOutput("Hidden", {hidden_n->Name()});
95102
op_desc.SetOutput("Cell", {cell_n->Name()});
96103
op_desc.SetOutput("XX", {xx_n->Name()});
97-
op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
104+
op_desc.SetOutput("BatchedGate", {name_scope + "/BatchedGate.new"});
105+
op_desc.SetOutput("BatchCellPreAct", {name_scope + "/BatchCellPreAct.new"});
106+
op_desc.SetOutput("BatchedInput", {name_scope + "/BatchedInput.new"});
98107
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
99108
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
100109
// TODO(TJ): get from attr
@@ -130,8 +139,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
130139

131140
int fusion_count{0};
132141

133-
auto fc_no_bias_handler = [&](
134-
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
142+
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
143+
Graph* g) {
135144
#define GET_NODE(name__) \
136145
std::string name__##key = name_scope + "/" + #name__; \
137146
auto* name__##n = pattern->RetrieveNode(name__##key); \
@@ -152,21 +161,24 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
152161

153162
if (with_fc_bias) {
154163
GET_NODE(fc_bias);
164+
GET_NODE(elementwise_add);
155165
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias);
166+
// Remove unneeded nodes.
167+
std::unordered_set<const Node*> marked_nodes(
168+
{mul_n, lstm_n, elementwise_add_n});
169+
GraphSafeRemoveNodes(graph, marked_nodes);
156170
} else {
157171
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1);
172+
// Remove unneeded nodes.
173+
std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n});
174+
GraphSafeRemoveNodes(graph, marked_nodes);
158175
}
159176
#undef GET_NODE
160177

161-
// Remove unneeded nodes.
162-
std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n});
163-
164-
GraphSafeRemoveNodes(graph, marked_nodes);
165-
166178
++fusion_count;
167179
};
168180

169-
gpd(graph, fc_no_bias_handler);
181+
gpd(graph, handler);
170182

171183
return fusion_count;
172184
}

paddle/fluid/framework/ir/fc_lstm_fuse_pass.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#pragma once
16+
1517
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
1618
#include "paddle/fluid/framework/ir/graph.h"
1719
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
7373
void GraphPatternDetector::operator()(Graph* graph,
7474
GraphPatternDetector::handle_t handler) {
7575
if (!MarkPDNodesInGraph(*graph)) {
76-
LOG(INFO) << "Mark failed";
7776
return;
7877
}
7978

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#endif
2020

2121
#include <numeric>
22+
#include <string>
23+
#include <utility>
24+
#include <vector>
2225
#include "paddle/fluid/framework/ir/graph.h"
2326
#include "paddle/fluid/framework/ir/node.h"
2427
#include "paddle/fluid/inference/analysis/dot.h"

paddle/fluid/inference/analysis/CMakeLists.txt

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ endif()
5858
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
5959
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor
6060
ARGS --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
61-
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
61+
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
6262

6363
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
6464
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
@@ -74,7 +74,7 @@ inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)
7474
set(CHINESE_NER_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/chinese_ner_model.tar.gz")
7575
set(CHINESE_NER_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/chinese_ner-data.txt.tar.gz")
7676
set(CHINESE_NER_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/chinese_ner" CACHE PATH "Chinese ner model and data root." FORCE)
77-
if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR} AND WITH_TESTING)
77+
if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
7878
inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_MODEL_URL} "chinese_ner_model.tar.gz")
7979
inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_DATA_URL} "chinese_ner-data.txt.tar.gz")
8080
endif()
@@ -87,7 +87,7 @@ inference_analysis_test(test_analyzer_ner SRCS analyzer_ner_tester.cc
8787
set(LAC_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/lac_model.tar.gz")
8888
set(LAC_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/lac_data.txt.tar.gz")
8989
set(LAC_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/lac" CACHE PATH "LAC model and data root." FORCE)
90-
if (NOT EXISTS ${LAC_INSTALL_DIR} AND WITH_TESTING)
90+
if (NOT EXISTS ${LAC_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
9191
inference_download_and_uncompress(${LAC_INSTALL_DIR} ${LAC_MODEL_URL} "lac_model.tar.gz")
9292
inference_download_and_uncompress(${LAC_INSTALL_DIR} ${LAC_DATA_URL} "lac_data.txt.tar.gz")
9393
endif()
@@ -96,3 +96,15 @@ inference_analysis_test(test_analyzer_lac SRCS analyzer_lac_tester.cc
9696
EXTRA_DEPS paddle_inference_api paddle_fluid_api
9797
ARGS --infer_model=${LAC_INSTALL_DIR}/model
9898
--infer_data=${LAC_INSTALL_DIR}/data.txt)
99+
100+
101+
set(TEXT_CLASSIFICATION_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/text-classification-Senta.tar.gz")
102+
set(TEXT_CLASSIFICATION_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/text_classification" CACHE PATH "Text Classification model and data root." FORCE)
103+
104+
if (NOT EXISTS ${TEXT_CLASSIFICATION_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
105+
inference_download_and_uncompress(${TEXT_CLASSIFICATION_INSTALL_DIR} ${TEXT_CLASSIFICATION_MODEL_URL} "text-classification-Senta.tar.gz")
106+
endif()
107+
108+
inference_analysis_test(test_text_classification SRCS test_text_classification.cc
109+
EXTRA_DEPS paddle_inference_api paddle_fluid_api analysis_predictor
110+
ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta)

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/fluid/inference/analysis/analyzer.h"
1616
#include <string>
17+
#include <vector>
1718
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
1819
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
1920
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
@@ -41,20 +42,16 @@ class DfgPassManagerImpl final : public DfgPassManager {
4142
public:
4243
DfgPassManagerImpl() {
4344
// TODO(Superjomn) set the key with pass reprs.
44-
LOG(INFO)
45-
<< "-----------------------------------------------------------------";
46-
if (FLAGS_IA_enable_ir) {
47-
AddPass("fluid-to-ir-pass", new FluidToIrPass);
48-
} else {
45+
if (!FLAGS_IA_enable_ir) {
4946
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
47+
} else {
48+
AddPass("fluid-to-ir-pass", new FluidToIrPass);
5049
}
5150
TryAddTensorRtPass();
5251
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
5352
if (!FLAGS_IA_output_storage_path.empty()) {
5453
AddPass("model-store-pass", new ModelStorePass);
5554
}
56-
LOG(INFO)
57-
<< "-----------------------------------------------------------------";
5855
}
5956

6057
std::string repr() const override { return "dfg-pass-manager"; }
@@ -101,19 +98,16 @@ class DfgPassManagerImpl final : public DfgPassManager {
10198
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
10299

103100
void Analyzer::Run(Argument* argument) {
101+
std::vector<std::string> passes;
102+
for (auto& pass : all_ir_passes_) {
103+
if (!disabled_ir_passes_.count(pass)) {
104+
passes.push_back(pass);
105+
passes.push_back("graph_viz_pass"); // add graphviz for debug.
106+
}
107+
}
108+
passes.push_back("graph_viz_pass");
104109
// Ugly support fluid-to-ir-pass
105-
argument->Set(kFluidToIrPassesAttr,
106-
new std::vector<std::string>({
107-
// Manual update the passes here.
108-
"graph_viz_pass", //
109-
"infer_clean_graph_pass", "graph_viz_pass", //
110-
"attention_lstm_fuse_pass", "graph_viz_pass", //
111-
"fc_lstm_fuse_pass", "graph_viz_pass", //
112-
"mul_lstm_fuse_pass", "graph_viz_pass", //
113-
"seq_concat_fc_fuse_pass", "graph_viz_pass", //
114-
"fc_fuse_pass", "graph_viz_pass" //
115-
116-
}));
110+
argument->Set(kFluidToIrPassesAttr, new std::vector<std::string>(passes));
117111

118112
for (auto& x : data_) {
119113
PADDLE_ENFORCE(x->Initialize(argument));
@@ -122,6 +116,11 @@ void Analyzer::Run(Argument* argument) {
122116
}
123117
}
124118

119+
Analyzer& Analyzer::DisableIrPasses(const std::vector<std::string>& passes) {
120+
disabled_ir_passes_.insert(passes.begin(), passes.end());
121+
return *this;
122+
}
123+
125124
} // namespace analysis
126125
} // namespace inference
127126
} // namespace paddle

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,10 @@ limitations under the License. */
3636
*/
3737

3838
#include <gflags/gflags.h>
39+
#include "paddle/fluid/inference/analysis/flags.h"
3940
#include "paddle/fluid/inference/analysis/pass.h"
4041
#include "paddle/fluid/inference/analysis/pass_manager.h"
4142

42-
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
43-
// flag if not available.
44-
DECLARE_bool(IA_enable_tensorrt_subgraph_engine);
45-
DECLARE_string(IA_graphviz_log_root);
46-
DECLARE_string(IA_output_storage_path);
47-
DECLARE_bool(IA_enable_ir);
48-
4943
namespace paddle {
5044
namespace inference {
5145
namespace analysis {
@@ -57,7 +51,26 @@ class Analyzer : public OrderedRegistry<PassManager> {
5751

5852
void Run(Argument* argument);
5953

54+
Analyzer& DisableIrPasses(const std::vector<std::string>& passes);
55+
6056
DISABLE_COPY_AND_ASSIGN(Analyzer);
57+
58+
private:
59+
// All avaiable IR passes.
60+
// The bigger fuse comes first, so that the small operators prefer to be
61+
// merged in a larger fuse op. The small fusion will not break the pattern of
62+
// larger fusion.
63+
const std::vector<std::string> all_ir_passes_{{
64+
// Manual update the passes here.
65+
"infer_clean_graph_pass", //
66+
"attention_lstm_fuse_pass", //
67+
"fc_lstm_fuse_pass", //
68+
"mul_lstm_fuse_pass", //
69+
"seq_concat_fc_fuse_pass", //
70+
"fc_fuse_pass", //
71+
}};
72+
73+
std::unordered_set<std::string> disabled_ir_passes_;
6174
};
6275

6376
} // namespace analysis

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,17 +271,22 @@ void TestDituRNNPrediction(const std::string &model_path,
271271
const std::string &data_path, int batch_size,
272272
bool use_analysis, bool activate_ir,
273273
int num_times = 1) {
274-
NativeConfig config;
274+
AnalysisConfig config;
275275
config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
276276
config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
277277
config.use_gpu = false;
278278
config.device = 0;
279279
config.specify_input_name = true;
280+
config.enable_ir_optim = activate_ir;
281+
PADDLE_ENFORCE(config.ir_mode ==
282+
AnalysisConfig::IrPassMode::kExclude); // default
283+
config.ir_passes.clear(); // Do not exclude any pass.
280284

281285
auto base_predictor =
282286
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
283287
auto predictor =
284-
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config);
288+
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
289+
config);
285290
std::vector<PaddleTensor> input_slots;
286291
DataRecord data(data_path, batch_size);
287292
// Prepare inputs.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <gflags/gflags.h>
16+
17+
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
18+
// flag if not available.
19+
DECLARE_bool(IA_enable_tensorrt_subgraph_engine);
20+
DECLARE_string(IA_graphviz_log_root);
21+
DECLARE_string(IA_output_storage_path);
22+
DECLARE_bool(IA_enable_ir);

paddle/fluid/inference/analysis/fluid_to_ir_pass.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
18+
#include "paddle/fluid/inference/analysis/flags.h"
1819
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
1920
#include "paddle/fluid/inference/analysis/pass.h"
2021

@@ -85,9 +86,11 @@ class FluidToIrPass final : public DataFlowGraphPass {
8586
new Scope *(&argument_->Get<Scope>(ir::kParamScopeAttr)));
8687
}
8788

88-
const auto &ir_passes_to_apply =
89-
argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
90-
ir_passes.Apply(ir_passes_to_apply);
89+
if (FLAGS_IA_enable_ir) {
90+
const auto &ir_passes_to_apply =
91+
argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
92+
ir_passes.Apply(ir_passes_to_apply);
93+
}
9194

9295
PADDLE_ENFORCE(argument_->main_dfg.get());
9396
argument_->main_dfg->Build(ir_passes.graph());

0 commit comments

Comments
 (0)