Skip to content

Commit 6ef6a91

Browse files
authored
Merge pull request #13727 from Sand3r-/mgallus/enable-mkldnn-naive-exe
Enable MKL-DNN in Naive Executor and Analysis Predictor
2 parents 8cd17c0 + 09d9d77 commit 6ef6a91

File tree

4 files changed

+27
-2
lines changed

4 files changed

+27
-2
lines changed

paddle/fluid/framework/naive_executor.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,22 @@ void NaiveExecutor::CleanFeedFetchOps() {
146146
ops_.swap(ops);
147147
}
148148

149+
void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) {
150+
#ifdef PADDLE_WITH_MKLDNN
151+
VLOG(3) << "use_mkldnn=True";
152+
for (size_t block_id = 0; block_id < program.Size(); ++block_id) {
153+
auto *block = const_cast<ProgramDesc &>(program).MutableBlock(block_id);
154+
for (auto *op : block->AllOps()) {
155+
if (op->HasAttr("use_mkldnn")) {
156+
op->SetAttr("use_mkldnn", true);
157+
}
158+
}
159+
}
160+
#else
161+
LOG(WARNING)
162+
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
163+
#endif
164+
}
165+
149166
} // namespace framework
150167
} // namespace paddle

paddle/fluid/framework/naive_executor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#pragma once
1616

17+
#include <string>
18+
#include <vector>
1719
#include "paddle/fluid/framework/operator.h"
1820
#include "paddle/fluid/framework/program_desc.h"
1921
#include "paddle/fluid/framework/scope.h"
@@ -46,6 +48,8 @@ class NaiveExecutor {
4648

4749
void CleanFeedFetchOps();
4850

51+
void EnableMKLDNN(const ProgramDesc& program);
52+
4953
protected:
5054
void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id);
5155

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ bool AnalysisPredictor::Init(
7171
} else {
7272
inference_program_ = program;
7373
}
74+
75+
if (config_._use_mkldnn) {
76+
executor_->EnableMKLDNN(*inference_program_);
77+
}
78+
7479
executor_->Prepare(scope_.get(), *inference_program_, 0,
7580
config_.use_feed_fetch_ops);
7681

@@ -92,6 +97,7 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
9297
LOG(ERROR) << "fail to set feed";
9398
return false;
9499
}
100+
95101
// Run the inference program
96102
// if share variables, we need not create variables
97103
executor_->Run();

paddle/fluid/inference/tests/api/analyzer_vis_tester.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ void SetConfig(AnalysisConfig *cfg) {
6161
cfg->ir_passes.push_back("fc_gru_fuse_pass");
6262
#ifdef PADDLE_WITH_MKLDNN
6363
cfg->_use_mkldnn = true;
64-
// disable mkldnn fuse since it should have some bugs
65-
cfg->ir_passes.push_back("conv_relu_mkldnn_fuse_pass");
6664
#endif
6765
}
6866

0 commit comments

Comments
 (0)