File tree Expand file tree Collapse file tree 2 files changed +10
-5
lines changed
paddle/fluid/inference/api Expand file tree Collapse file tree 2 files changed +10
-5
lines changed Original file line number Diff line number Diff line change @@ -226,18 +226,21 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
226
226
argument_.origin_program_desc .reset (
227
227
new ProgramDesc (*inference_program_->Proto ()));
228
228
229
+ bool use_mkldnn = config_._use_mkldnn ;
229
230
switch (config_.ir_mode ) {
230
231
case contrib::AnalysisConfig::IrPassMode::kExclude :
231
232
Analyzer ()
232
233
.IncludeAllIrPasses ()
233
- .SetUseMkldnn (config_._use_mkldnn )
234
- .DisableIrPasses (config_.ir_passes )
234
+ .SetUseMkldnn (use_mkldnn)
235
+ .DisableIrPasses (use_mkldnn ? config_.ir_mkldnn_passes
236
+ : config_.ir_passes )
235
237
.Run (&argument_);
236
238
break ;
237
239
case contrib::AnalysisConfig::IrPassMode::kInclude :
238
240
Analyzer ()
239
- .SetUseMkldnn (config_._use_mkldnn )
240
- .IncludeIrPasses (config_.ir_passes )
241
+ .SetUseMkldnn (use_mkldnn)
242
+ .IncludeIrPasses (use_mkldnn ? config_.ir_mkldnn_passes
243
+ : config_.ir_passes )
241
244
.Run (&argument_);
242
245
break ;
243
246
default :
Original file line number Diff line number Diff line change @@ -261,8 +261,8 @@ struct AnalysisConfig : public NativeConfig {
261
261
262
262
void SetIncludeMode () {
263
263
ir_mode = IrPassMode::kInclude ;
264
- // this pass has to be run at the beginning of all fuse passes
265
264
ir_passes = {" infer_clean_graph_pass" };
265
+ ir_mkldnn_passes = {" infer_clean_graph_pass" };
266
266
}
267
267
268
268
// Determine whether to perform graph optimization.
@@ -271,6 +271,8 @@ struct AnalysisConfig : public NativeConfig {
271
271
IrPassMode ir_mode{IrPassMode::kExclude };
272
272
// passes to be excluded/included
273
273
std::vector<std::string> ir_passes{" embedding_fc_lstm_fuse_pass" };
274
+ // passes to be excluded/included when MKL-DNN is enabled
275
+ std::vector<std::string> ir_mkldnn_passes{" embedding_fc_lstm_fuse_pass" };
274
276
275
277
// NOT stable yet.
276
278
bool use_feed_fetch_ops{true };
You can’t perform that action at this time.
0 commit comments