Skip to content

Commit a5c4b46

Browse files
committed
add SetMKLDNNThreadId api
1 parent e21edb2 commit a5c4b46

File tree

4 files changed

+17
-4
lines changed

4 files changed

+17
-4
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,14 @@ bool AnalysisPredictor::PrepareExecutor() {
159159
return true;
160160
}
161161

162+
void AnalysisPredictor::SetMKLDNNThreadId(int tid) {
163+
#ifdef PADDLE_WITH_MKLDNN
164+
platform::set_cur_thread_id(tid);
165+
#else
166+
LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
167+
#endif
168+
}
169+
162170
bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
163171
std::vector<PaddleTensor> *output_data,
164172
int batch_size) {

paddle/fluid/inference/api/analysis_predictor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class AnalysisPredictor : public PaddlePredictor {
6969
framework::Scope *scope() { return scope_.get(); }
7070
framework::ProgramDesc &program() { return *inference_program_; }
7171

72+
void SetMKLDNNThreadId(int tid);
73+
7274
protected:
7375
bool PrepareProgram(const std::shared_ptr<framework::ProgramDesc> &program);
7476
bool PrepareScope(const std::shared_ptr<framework::Scope> &parent_scope);

paddle/fluid/inference/api/paddle_analysis_config.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ struct AnalysisConfig : public NativeConfig {
5151
int max_batch_size = 1);
5252
bool use_tensorrt() const { return use_tensorrt_; }
5353

54+
void EnableMKLDNN();
5455
// NOTE this is just for internal development, please not use it.
5556
// NOT stable yet.
56-
void EnableMKLDNN();
5757
bool use_mkldnn() const { return use_mkldnn_; }
5858

5959
friend class ::paddle::AnalysisPredictor;

paddle/fluid/inference/tests/api/tester_helper.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,16 @@ void TestMultiThreadPrediction(
216216
size_t total_time{0};
217217
for (int tid = 0; tid < num_threads; ++tid) {
218218
threads.emplace_back([&, tid]() {
219-
#ifdef PADDLE_WITH_MKLDNN
220-
platform::set_cur_thread_id(static_cast<int>(tid) + 1);
221-
#endif
222219
// Each thread should have local inputs and outputs.
223220
// The inputs of each thread are all the same.
224221
std::vector<PaddleTensor> outputs_tid;
225222
auto &predictor = predictors[tid];
223+
#ifdef PADDLE_WITH_MKLDNN
224+
if (use_analysis) {
225+
static_cast<AnalysisPredictor *>(predictor.get())
226+
->SetMKLDNNThreadId(static_cast<int>(tid) + 1);
227+
}
228+
#endif
226229

227230
// warmup run
228231
LOG(INFO) << "Running thread " << tid << ", warm up run...";

0 commit comments

Comments
 (0)