
Commit c6d230e

add FLAGS_use_mkldnn to globally control use_mkldnn
1 parent 0aa9546 commit c6d230e

7 files changed: +29 -29 lines changed

paddle/fluid/framework/executor.cc

Lines changed: 17 additions & 1 deletion
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/profiler.h"

 DECLARE_bool(benchmark);
+DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");

 namespace paddle {
 namespace framework {
@@ -115,6 +116,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
 void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                    bool create_local_scope, bool create_vars) {
   platform::RecordBlock b(block_id);
+  if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
   auto ctx = Prepare(pdesc, block_id);
   RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
 }
@@ -214,6 +216,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
                    const std::string& feed_holder_name,
                    const std::string& fetch_holder_name) {
   platform::RecordBlock b(kProgramId);
+  if (FLAGS_use_mkldnn) EnableMKLDNN(program);
   bool has_feed_ops =
       has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
   bool has_fetch_ops =
@@ -225,7 +228,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
     unique_ptr_of_copy_program.reset(new ProgramDesc(program));
     copy_program = unique_ptr_of_copy_program.get();
   }
-
   auto* global_block = copy_program->MutableBlock(0);

   if (!has_feed_ops) {
@@ -378,5 +380,19 @@ void Executor::RunPreparedContext(
   }
 }

+void Executor::EnableMKLDNN(const ProgramDesc& program) {
+#ifdef PADDLE_WITH_MKLDNN
+  VLOG(3) << "use_mkldnn=True";
+  for (size_t bid = 0; bid < program.Size(); ++bid) {
+    auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
+    for (auto* op : block->AllOps()) {
+      if (op->HasAttr("use_mkldnn")) {
+        op->SetAttr("use_mkldnn", true);
+      }
+    }
+  }
+#endif
+}
+
 }  // namespace framework
 }  // namespace paddle
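
The new EnableMKLDNN is a plain attribute rewrite: walk every block of the program and, for each op that exposes a use_mkldnn attribute, force it to true. The const_cast is only needed because Run receives the program by const reference; the rewrite mutates the op descs in place. Below is a self-contained sketch of the same traversal, with ToyOp/ToyBlock/ToyProgram as illustrative stand-ins for Paddle's OpDesc/BlockDesc/ProgramDesc (this is not the real API):

// enable_mkldnn_sketch.cc -- toy model of the traversal above.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct ToyOp {
  std::string type;
  std::map<std::string, bool> attrs;  // only bool attrs, for brevity
  bool HasAttr(const std::string& name) const { return attrs.count(name) > 0; }
  void SetAttr(const std::string& name, bool v) { attrs[name] = v; }
};

struct ToyBlock { std::vector<ToyOp> ops; };

struct ToyProgram {
  std::vector<ToyBlock> blocks;
  size_t Size() const { return blocks.size(); }
};

// Same shape as Executor::EnableMKLDNN: flip use_mkldnn wherever it exists.
void EnableMKLDNN(ToyProgram* program) {
  for (size_t bid = 0; bid < program->Size(); ++bid) {
    for (auto& op : program->blocks[bid].ops) {
      if (op.HasAttr("use_mkldnn")) op.SetAttr("use_mkldnn", true);
    }
  }
}

int main() {
  ToyProgram prog;
  ToyBlock block;
  block.ops.push_back(ToyOp{"conv2d", {{"use_mkldnn", false}}});
  block.ops.push_back(ToyOp{"fetch", {}});  // no use_mkldnn attr: untouched
  prog.blocks.push_back(block);
  EnableMKLDNN(&prog);
  for (const auto& op : prog.blocks[0].ops) {
    bool on = op.HasAttr("use_mkldnn") && op.attrs.at("use_mkldnn");
    std::cout << op.type << " use_mkldnn=" << std::boolalpha << on << "\n";
  }
  return 0;
}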

paddle/fluid/framework/executor.h

Lines changed: 2 additions & 0 deletions
@@ -81,6 +81,8 @@ class Executor {
                   const std::string& feed_holder_name = "feed",
                   const std::string& fetch_holder_name = "fetch");

+  void EnableMKLDNN(const ProgramDesc& program);
+
 private:
  const platform::Place place_;
};
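
Exposing EnableMKLDNN as a public method lets any caller honor the flag before preparing a context; callers that go through Executor::Run get the check for free, since Run now tests FLAGS_use_mkldnn itself. A fragment of the intended call pattern, assuming a built fluid library and that place, scope, and inference_program are already set up (it mirrors the test_helper.h change below):

// Fragment, not a complete program: requires the Paddle fluid headers/library.
paddle::framework::Executor executor(place);
if (FLAGS_use_mkldnn) executor.EnableMKLDNN(*inference_program);
auto ctx = executor.Prepare(*inference_program, 0);
executor.RunPreparedContext(ctx.get(), scope, /*create_local_scope=*/true,
                            /*create_vars=*/true);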

paddle/fluid/inference/tests/book/test_inference_image_classification.cc

Lines changed: 1 addition & 4 deletions
@@ -21,7 +21,6 @@ DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model.");
 DEFINE_int32(batch_size, 1, "Batch size of input data");
 DEFINE_int32(repeat, 1, "Running the inference program repeat times");
 DEFINE_bool(skip_cpu, false, "Skip the cpu test");
-DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference");

 TEST(inference, image_classification) {
   if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) {
@@ -59,10 +58,8 @@ TEST(inference, image_classification) {
   // Run inference on CPU
   LOG(INFO) << "--- CPU Runs: ---";
   LOG(INFO) << "Batch size is " << FLAGS_batch_size;
-  LOG(INFO) << "FLAGS_use_mkldnn: " << FLAGS_use_mkldnn;
   TestInference<paddle::platform::CPUPlace, false, true>(
-      dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined,
-      FLAGS_use_mkldnn);
+      dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined);
   LOG(INFO) << output1.dims();
 }

paddle/fluid/inference/tests/book/test_inference_nlp.cc

Lines changed: 0 additions & 4 deletions
@@ -27,7 +27,6 @@ limitations under the License. */
 DEFINE_string(model_path, "", "Directory of the inference model.");
 DEFINE_string(data_file, "", "File of input index data.");
 DEFINE_int32(repeat, 100, "Running the inference program repeat times");
-DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference");
 DEFINE_bool(prepare_vars, true, "Prepare variables before executor");
 DEFINE_int32(num_threads, 1, "Number of threads should be used");

@@ -190,9 +189,6 @@ TEST(inference, nlp) {
   std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
   inference_program = InitProgram(&executor, scope.get(), FLAGS_model_path,
                                   /*model combined*/ false);
-  if (FLAGS_use_mkldnn) {
-    EnableMKLDNN(inference_program);
-  }
   // always prepare context
   std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
   ctx = executor.Prepare(*inference_program, 0);
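
Deleting the test-local DEFINE_bool(use_mkldnn, ...) here and in the image-classification test above is required, not just tidy: executor.cc now owns that definition, and gflags aborts at startup when two linked translation units define a flag with the same name. The tests still exercise MKLDNN through the global flag, set from the environment; a hypothetical invocation (binary and path illustrative):

FLAGS_use_mkldnn=true ./test_inference_nlp --model_path=/path/to/model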

paddle/fluid/inference/tests/test_helper.h

Lines changed: 7 additions & 18 deletions
@@ -22,6 +22,8 @@ limitations under the License. */
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/profiler.h"

+DECLARE_bool(use_mkldnn);
+
 template <typename T>
 void SetupTensor(paddle::framework::LoDTensor* input,
                  paddle::framework::DDim dims, T lower, T upper) {
@@ -133,24 +135,11 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
   return feed_target_shapes;
 }

-void EnableMKLDNN(
-    const std::unique_ptr<paddle::framework::ProgramDesc>& program) {
-  for (size_t bid = 0; bid < program->Size(); ++bid) {
-    auto* block = program->MutableBlock(bid);
-    for (auto* op : block->AllOps()) {
-      if (op->HasAttr("use_mkldnn")) {
-        op->SetAttr("use_mkldnn", true);
-      }
-    }
-  }
-}
-
 template <typename Place, bool CreateVars = true, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
-                   const int repeat = 1, const bool is_combined = false,
-                   const bool use_mkldnn = false) {
+                   const int repeat = 1, const bool is_combined = false) {
   // 1. Define place, executor, scope
   auto place = Place();
   auto executor = paddle::framework::Executor(place);
@@ -182,9 +171,6 @@ void TestInference(const std::string& dirname,
         "init_program",
         paddle::platform::DeviceContextPool::Instance().Get(place));
     inference_program = InitProgram(&executor, scope, dirname, is_combined);
-    if (use_mkldnn) {
-      EnableMKLDNN(inference_program);
-    }
   }
   // Disable the profiler and print the timing information
   paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
@@ -210,7 +196,10 @@ void TestInference(const std::string& dirname,
     fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
   }

-  // 6. Run the inference program
+  // 6. If FLAGS_use_mkldnn is set (e.g. exported in the environment),
+  //    enable the MKLDNN ops.
+  if (FLAGS_use_mkldnn) executor.EnableMKLDNN(*inference_program);
+
+  // 7. Run the inference program
   {
     if (!CreateVars) {
       // If users don't want to create and destroy variables every time they
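
The DECLARE_bool here is the other half of the DEFINE_bool added in executor.cc: gflags allows exactly one definition per flag, and every other translation unit reaches the same FLAGS_use_mkldnn storage through a declaration. A minimal standalone illustration of the pattern, using a toy flag name (assumes gflags is installed; built with something like g++ flag_demo.cc -lgflags):

// flag_demo.cc -- DEFINE/DECLARE sharing pattern, independent of Paddle.
#include <gflags/gflags.h>
#include <iostream>

// Exactly one translation unit owns the definition...
DEFINE_bool(use_toy_feature, false, "Toggle the toy feature");
// ...any other file shares it via:  DECLARE_bool(use_toy_feature);
// (DECLARE_bool expands to an extern declaration of FLAGS_use_toy_feature.)

int main(int argc, char** argv) {
  // `true` removes recognized flags from argv after parsing. Older gflags
  // installs expose this function as google::ParseCommandLineFlags instead.
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  std::cout << "use_toy_feature = " << std::boolalpha << FLAGS_use_toy_feature
            << "\n";
  return 0;
}

Run it as ./flag_demo --use_toy_feature=true, or let the environment drive it with --tryfromenv=use_toy_feature plus an exported FLAGS_use_toy_feature=true, which is exactly the route paddle_gtest_main.cc wires up below.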

paddle/testing/paddle_gtest_main.cc

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ int main(int argc, char** argv) {
   new_argv.push_back(
       strdup("--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory"));
 #else
-  new_argv.push_back(strdup("--tryfromenv=use_pinned_memory"));
+  new_argv.push_back(strdup("--tryfromenv=use_pinned_memory,use_mkldnn"));
 #endif
   int new_argc = static_cast<int>(new_argv.size());
   char** new_argv_address = new_argv.data();
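
--tryfromenv is stock gflags behavior: for each name in the comma-separated list, gflags looks up the environment variable FLAGS_<name> and, when it is set, applies its value as if it had been passed on the command line; unlike --fromenv, names with no matching variable are skipped silently. Adding use_mkldnn to the non-CUDA list therefore makes an exported FLAGS_use_mkldnn visible to every CPU test binary without touching any individual test's flags.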

python/paddle/fluid/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def __bootstrap__():

     read_env_flags = [
         'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir',
-        'eager_delete_scope'
+        'eager_delete_scope', 'use_mkldnn'
     ]
     if core.is_compiled_with_cuda():
         read_env_flags += [
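
On the Python side, __bootstrap__ forwards every name in read_env_flags to the C++ flag parser through the same try-from-environment route, so listing 'use_mkldnn' makes a plain export sufficient; no new Python API is involved. A hypothetical invocation (script name illustrative):

FLAGS_use_mkldnn=true python infer.py

The value lands in the FLAGS_use_mkldnn defined in executor.cc, and Executor::Run then rewrites the program's ops accordingly.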
