Skip to content

Commit 79da263

Browse files
authored
Merge pull request #14032 from sfraczek/sfraczek/fix-test-multithreading-mkldnn
Fix the ResNet50 multi-threading test on MKL-DNN
2 parents 26200f2 + 2098b42 commit 79da263

File tree

4 files changed

+63
-18
lines changed

4 files changed

+63
-18
lines changed

paddle/fluid/inference/api/helper.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ static void PrintTime(int batch_size, int repeat, int num_threads, int tid,
160160
double latency, int epoch = 1) {
161161
LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat
162162
<< ", threads: " << num_threads << ", thread id: " << tid
163-
<< ", latency: " << latency << "ms ======";
163+
<< ", latency: " << latency << "ms, fps: " << 1 / (latency / 1000.f)
164+
<< " ======";
164165
if (epoch > 1) {
165166
int samples = batch_size * epoch;
166167
LOG(INFO) << "====== sample number: " << samples

paddle/fluid/inference/tests/api/tester_helper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ void TestMultiThreadPrediction(
139139
}
140140
for (int tid = 0; tid < num_threads; ++tid) {
141141
threads.emplace_back([&, tid]() {
142+
#ifdef PADDLE_WITH_MKLDNN
143+
platform::set_cur_thread_id(static_cast<int>(tid) + 1);
144+
#endif
142145
// Each thread should have local inputs and outputs.
143146
// The inputs of each thread are all the same.
144147
std::vector<std::vector<PaddleTensor>> inputs_tid = inputs;

paddle/fluid/platform/device_context.cc

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -296,38 +296,73 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
296296

297297
#ifdef PADDLE_WITH_MKLDNN
298298
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
299-
: CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() {
300-
p_blobs_.reset(new std::unordered_map<std::string, std::shared_ptr<void>>());
299+
: CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() {
300+
p_blobmap_.reset(new BlobMap());
301+
p_mutex_.reset(new std::mutex());
301302
}
302303

304+
namespace {
305+
// Current thread's id.
306+
thread_local int cur_thread_id = 0;
307+
}
308+
309+
void set_cur_thread_id(int tid) { cur_thread_id = tid; }
310+
int get_cur_thread_id(void) { return cur_thread_id; }
311+
303312
void MKLDNNDeviceContext::SetBlob(const std::string& name,
304313
std::shared_ptr<void> data) const {
305-
std::unordered_map<std::string, std::shared_ptr<void>>* p;
306-
p = p_blobs_.get();
314+
BlobMap* pMap = p_blobmap_.get();
315+
std::shared_ptr<KeyBlob> pBlob = nullptr;
316+
317+
int tid = platform::get_cur_thread_id();
307318

308-
auto it = p->find(name);
319+
std::lock_guard<std::mutex> lock(*p_mutex_.get());
309320

310-
if (it == p->end()) {
311-
(*p)[name] = data; // create new blob
321+
// Find KeyBlob for current thread
322+
auto map_it = pMap->find(tid);
323+
324+
if (map_it == pMap->end()) {
325+
// 1st time to set blob in current thread
326+
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
327+
(*pMap)[tid] = pBlob;
312328
} else {
313-
it->second = data; // set data to existing blob
329+
pBlob = map_it->second;
314330
}
315331

332+
// Find Key in found (or newly created) KeyBlob
333+
auto key_it = pBlob->find(name);
334+
335+
if (key_it == pBlob->end()) {
336+
(*pBlob)[name] = data; // create new blob
337+
} else {
338+
key_it->second = data; // set data to existing blob
339+
}
340+
341+
// lock will be automatically released when out of scope
316342
return;
317343
}
318344

319345
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
320346
const std::string& name) const {
321-
std::unordered_map<std::string, std::shared_ptr<void>>* p;
322-
p = p_blobs_.get();
347+
BlobMap* pMap = p_blobmap_.get();
348+
std::shared_ptr<KeyBlob> pBlob = nullptr;
323349

324-
auto it = p->find(name);
350+
int tid = platform::get_cur_thread_id();
325351

326-
if (it != p->end()) {
327-
return it->second;
328-
}
352+
std::lock_guard<std::mutex> lock(*p_mutex_.get());
353+
354+
// Find the KeyBlob for the current thread first
355+
auto map_it = pMap->find(tid);
356+
if (map_it == pMap->end()) return nullptr;
357+
pBlob = map_it->second;
358+
359+
// Find Blob via name
360+
auto key_it = pBlob->find(name);
361+
362+
if (key_it == pBlob->end()) return nullptr;
329363

330-
return nullptr;
364+
// lock will be automatically released when out of scope
365+
return key_it->second;
331366
}
332367

333368
#endif

paddle/fluid/platform/device_context.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
176176
#endif
177177

178178
#ifdef PADDLE_WITH_MKLDNN
179+
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
180+
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;
181+
182+
void set_cur_thread_id(int);
183+
int get_cur_thread_id(void);
184+
179185
class MKLDNNDeviceContext : public CPUDeviceContext {
180186
public:
181187
explicit MKLDNNDeviceContext(CPUPlace place);
@@ -191,8 +197,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
191197

192198
private:
193199
mkldnn::engine engine_;
194-
std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<void>>>
195-
p_blobs_;
200+
std::shared_ptr<BlobMap> p_blobmap_;
201+
std::shared_ptr<std::mutex> p_mutex_;
196202
};
197203
#endif
198204

0 commit comments

Comments
 (0)