Skip to content

Commit 5632019

Browse files
Wojciech UssSand3r-
authored andcommitted
add MKL-DNN placement pass
This patch also refactors conv+bn (includes changes from PR #13926) updated to use the mkldnn-placement-pass. test=develop
1 parent 0a9f5f1 commit 5632019

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,18 +226,21 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
226226
argument_.origin_program_desc.reset(
227227
new ProgramDesc(*inference_program_->Proto()));
228228

229+
bool use_mkldnn = config_._use_mkldnn;
229230
switch (config_.ir_mode) {
230231
case contrib::AnalysisConfig::IrPassMode::kExclude:
231232
Analyzer()
232233
.IncludeAllIrPasses()
233-
.SetUseMkldnn(config_._use_mkldnn)
234-
.DisableIrPasses(config_.ir_passes)
234+
.SetUseMkldnn(use_mkldnn)
235+
.DisableIrPasses(use_mkldnn ? config_.ir_mkldnn_passes
236+
: config_.ir_passes)
235237
.Run(&argument_);
236238
break;
237239
case contrib::AnalysisConfig::IrPassMode::kInclude:
238240
Analyzer()
239-
.SetUseMkldnn(config_._use_mkldnn)
240-
.IncludeIrPasses(config_.ir_passes)
241+
.SetUseMkldnn(use_mkldnn)
242+
.IncludeIrPasses(use_mkldnn ? config_.ir_mkldnn_passes
243+
: config_.ir_passes)
241244
.Run(&argument_);
242245
break;
243246
default:

paddle/fluid/inference/api/paddle_inference_api.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,8 @@ struct AnalysisConfig : public NativeConfig {
261261

262262
void SetIncludeMode() {
263263
ir_mode = IrPassMode::kInclude;
264-
// this pass has to be run at the beginning of all fuse passes
265264
ir_passes = {"infer_clean_graph_pass"};
265+
ir_mkldnn_passes = {"infer_clean_graph_pass"};
266266
}
267267

268268
// Determine whether to perform graph optimization.
@@ -271,6 +271,8 @@ struct AnalysisConfig : public NativeConfig {
271271
IrPassMode ir_mode{IrPassMode::kExclude};
272272
// passes to be excluded/included
273273
std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"};
274+
// passes to be excluded/included when MKL-DNN is enabled
275+
std::vector<std::string> ir_mkldnn_passes{"embedding_fc_lstm_fuse_pass"};
274276

275277
// NOT stable yet.
276278
bool use_feed_fetch_ops{true};

0 commit comments

Comments
 (0)