Skip to content

Commit 910cd41

Browse files
committed
- Disabled embedding_fc_lstm_fuse by defult and
extended test_text_classification ot use new op
1 parent d5114c6 commit 910cd41

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

paddle/fluid/inference/api/paddle_inference_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ struct AnalysisConfig : public NativeConfig {
216216
bool enable_ir_optim = true;
217217
// Manually determine the IR passes to run.
218218
IrPassMode ir_mode{IrPassMode::kExclude};
219-
std::vector<std::string> ir_passes;
219+
std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"};
220220

221221
// NOTE this is just for internal development, please not use it.
222222
bool _use_mkldnn{false};

paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,18 @@ TEST(Analyzer_Text_Classification, compare) {
104104
CompareNativeAndAnalysis(cfg, input_slots_all);
105105
}
106106

107+
TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
108+
AnalysisConfig cfg;
109+
SetConfig(&cfg);
110+
// Enable embedding_fc_lstm_fuse_pass (disabled by default)
111+
auto it = std::find(cfg.ir_passes.begin(), cfg.ir_passes.end(),
112+
"embedding_fc_lstm_fuse_pass");
113+
if (it != cfg.ir_passes.end()) cfg.ir_passes.erase(it);
114+
115+
std::vector<std::vector<PaddleTensor>> input_slots_all;
116+
SetInput(&input_slots_all);
117+
CompareNativeAndAnalysis(cfg, input_slots_all);
118+
}
119+
107120
} // namespace inference
108121
} // namespace paddle

0 commit comments

Comments
 (0)