@@ -25,6 +25,14 @@ namespace platform {
25
25
26
26
DeviceContextPool* DeviceContextPool::pool = nullptr ;
27
27
28
+ namespace {
29
+ // Current thread's id.
30
+ thread_local int cur_thread_id = 0 ;
31
+ }
32
+
33
+ void set_cur_thread_id (int tid) { cur_thread_id = tid; }
34
+ int get_cur_thread_id (void ) { return cur_thread_id; }
35
+
28
36
platform::DeviceContext* DeviceContextPool::Get (const platform::Place& place) {
29
37
auto it = device_contexts_.find (place);
30
38
if (it == device_contexts_.end ()) {
@@ -296,38 +304,65 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
296
304
297
305
#ifdef PADDLE_WITH_MKLDNN
298
306
MKLDNNDeviceContext::MKLDNNDeviceContext (CPUPlace place)
299
- : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0 ), p_blobs_() {
300
- p_blobs_.reset (new std::unordered_map<std::string, std::shared_ptr<void >>());
307
+ : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0 ), p_blobmap_() {
308
+ p_blobmap_.reset (new BlobMap ());
309
+ p_mutex_.reset (new std::mutex ());
301
310
}
302
311
303
312
void MKLDNNDeviceContext::SetBlob (const std::string& name,
304
313
std::shared_ptr<void > data) const {
305
- std::unordered_map<std::string, std::shared_ptr<void >>* p;
306
- p = p_blobs_.get ();
314
+ BlobMap* pMap = p_blobmap_.get ();
315
+ std::shared_ptr<KeyBlob> pBlob = nullptr ;
316
+
317
+ int tid = platform::get_cur_thread_id ();
307
318
308
- auto it = p-> find (name );
319
+ std::lock_guard<std::mutex> lock (*p_mutex_. get () );
309
320
310
- if (it == p->end ()) {
311
- (*p)[name] = data; // create new blob
321
+ // Find KeyBlob for current thread
322
+ auto map_it = pMap->find (tid);
323
+
324
+ if (map_it == pMap->end ()) {
325
+ // 1st time to set blob in current thread
326
+ pBlob = std::shared_ptr<KeyBlob>(new KeyBlob ());
327
+ (*pMap)[tid] = pBlob;
312
328
} else {
313
- it-> second = data; // set data to existing blob
329
+ pBlob = map_it-> second ;
314
330
}
315
331
332
+ // Find Key in found (or newly created) KeyBlob
333
+ auto key_it = pBlob->find (name);
334
+
335
+ if (key_it == pBlob->end ()) {
336
+ (*pBlob)[name] = data; // create new blob
337
+ } else {
338
+ key_it->second = data; // set data to existing blob
339
+ }
340
+
341
+ // lock will be automatically released when out of scope
316
342
return ;
317
343
}
318
344
319
345
std::shared_ptr<void > MKLDNNDeviceContext::GetBlob (
320
346
const std::string& name) const {
321
- std::unordered_map<std::string, std::shared_ptr< void >>* p ;
322
- p = p_blobs_. get () ;
347
+ BlobMap* pMap = p_blobmap_. get () ;
348
+ std::shared_ptr<KeyBlob> pBlob = nullptr ;
323
349
324
- auto it = p-> find (name );
350
+ int tid = platform::get_cur_thread_id ( );
325
351
326
- if (it != p->end ()) {
327
- return it->second ;
328
- }
352
+ std::lock_guard<std::mutex> lock (*p_mutex_.get ());
353
+
354
+ // Find KeyBlob for current thread firstly
355
+ auto map_it = pMap->find (tid);
356
+ if (map_it == pMap->end ()) return nullptr ;
357
+ pBlob = map_it->second ;
358
+
359
+ // Find Blob via name
360
+ auto key_it = pBlob->find (name);
361
+
362
+ if (key_it == pBlob->end ()) return nullptr ;
329
363
330
- return nullptr ;
364
+ // lock will be automatically released when out of scope
365
+ return key_it->second ;
331
366
}
332
367
333
368
#endif
0 commit comments