Skip to content

Commit a53e8a8

Browse files
liujianhang-designsfraczek
authored andcommitted
Update MKLDNN integration framework to support Paddle multi-instances
Make all blob info saved in global device context to be thread based. Meanwhile save thread id in thread local storage in ParallelDo
1 parent 2256fae commit a53e8a8

File tree

2 files changed

+58
-17
lines changed

2 files changed

+58
-17
lines changed

paddle/fluid/platform/device_context.cc

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ namespace platform {
2525

2626
DeviceContextPool* DeviceContextPool::pool = nullptr;
2727

28+
namespace {
29+
// Current thread's id.
30+
thread_local int cur_thread_id = 0;
31+
}
32+
33+
void set_cur_thread_id(int tid) { cur_thread_id = tid; }
34+
int get_cur_thread_id(void) { return cur_thread_id; }
35+
2836
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
2937
auto it = device_contexts_.find(place);
3038
if (it == device_contexts_.end()) {
@@ -296,38 +304,65 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
296304

297305
#ifdef PADDLE_WITH_MKLDNN
298306
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
299-
: CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() {
300-
p_blobs_.reset(new std::unordered_map<std::string, std::shared_ptr<void>>());
307+
: CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() {
308+
p_blobmap_.reset(new BlobMap());
309+
p_mutex_.reset(new std::mutex());
301310
}
302311

303312
void MKLDNNDeviceContext::SetBlob(const std::string& name,
304313
std::shared_ptr<void> data) const {
305-
std::unordered_map<std::string, std::shared_ptr<void>>* p;
306-
p = p_blobs_.get();
314+
BlobMap* pMap = p_blobmap_.get();
315+
std::shared_ptr<KeyBlob> pBlob = nullptr;
316+
317+
int tid = platform::get_cur_thread_id();
307318

308-
auto it = p->find(name);
319+
std::lock_guard<std::mutex> lock(*p_mutex_.get());
309320

310-
if (it == p->end()) {
311-
(*p)[name] = data; // create new blob
321+
// Find KeyBlob for current thread
322+
auto map_it = pMap->find(tid);
323+
324+
if (map_it == pMap->end()) {
325+
// 1st time to set blob in current thread
326+
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
327+
(*pMap)[tid] = pBlob;
312328
} else {
313-
it->second = data; // set data to existing blob
329+
pBlob = map_it->second;
314330
}
315331

332+
// Find Key in found (or newly created) KeyBlob
333+
auto key_it = pBlob->find(name);
334+
335+
if (key_it == pBlob->end()) {
336+
(*pBlob)[name] = data; // create new blob
337+
} else {
338+
key_it->second = data; // set data to existing blob
339+
}
340+
341+
// lock will be automatically released when out of scope
316342
return;
317343
}
318344

319345
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
320346
const std::string& name) const {
321-
std::unordered_map<std::string, std::shared_ptr<void>>* p;
322-
p = p_blobs_.get();
347+
BlobMap* pMap = p_blobmap_.get();
348+
std::shared_ptr<KeyBlob> pBlob = nullptr;
323349

324-
auto it = p->find(name);
350+
int tid = platform::get_cur_thread_id();
325351

326-
if (it != p->end()) {
327-
return it->second;
328-
}
352+
std::lock_guard<std::mutex> lock(*p_mutex_.get());
353+
354+
// Find KeyBlob for current thread firstly
355+
auto map_it = pMap->find(tid);
356+
if (map_it == pMap->end()) return nullptr;
357+
pBlob = map_it->second;
358+
359+
// Find Blob via name
360+
auto key_it = pBlob->find(name);
361+
362+
if (key_it == pBlob->end()) return nullptr;
329363

330-
return nullptr;
364+
// lock will be automatically released when out of scope
365+
return key_it->second;
331366
}
332367

333368
#endif

paddle/fluid/platform/device_context.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ limitations under the License. */
3939
namespace paddle {
4040
namespace platform {
4141

42+
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
43+
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;
44+
45+
void set_cur_thread_id(int);
46+
int get_cur_thread_id(void);
47+
4248
class DeviceContext {
4349
public:
4450
virtual ~DeviceContext() {}
@@ -191,8 +197,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
191197

192198
private:
193199
mkldnn::engine engine_;
194-
std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<void>>>
195-
p_blobs_;
200+
std::shared_ptr<BlobMap> p_blobmap_;
201+
std::shared_ptr<std::mutex> p_mutex_;
196202
};
197203
#endif
198204

0 commit comments

Comments
 (0)