Skip to content

Commit db2b6b6

Browse files
Hide globals & redesign restore PR (#24279)
test=develop
1 parent 4a105f8 commit db2b6b6

File tree

10 files changed

+136
-90
lines changed

10 files changed

+136
-90
lines changed

paddle/fluid/framework/data_layout_transform.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,10 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
124124
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
125125
"non-MKLDNN");
126126

127-
innerTransDataLayoutFromMKLDNN(in_layout,
128-
paddle::platform::get_cur_paddle_data_layout(),
129-
in, out, place);
127+
innerTransDataLayoutFromMKLDNN(
128+
in_layout,
129+
paddle::platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout(),
130+
in, out, place);
130131
}
131132

132133
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,

paddle/fluid/framework/data_transform.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
5959
// For NHWC data we need reshape of tensors as MKL-DNN
6060
// is expecting NHWC dims description order
6161
platform::MatchShapeToLayout(&out, lin, lout);
62-
paddle::platform::set_cur_paddle_data_layout(lin);
62+
paddle::platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
63+
lin);
6364
out.set_layout(DataLayout::kMKLDNN);
6465
out.set_format(out_format);
6566
} else {

paddle/fluid/framework/executor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ Executor::~Executor() {
8989
platform::MKLDNNDeviceContext* dev_ctx =
9090
(platform::MKLDNNDeviceContext*)pool.Get(place_);
9191
dev_ctx->ResetBlobMap();
92-
platform::set_cur_paddle_data_layout(paddle::framework::DataLayout::kNCHW);
92+
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
93+
paddle::framework::DataLayout::kNCHW);
9394
}
9495
#endif
9596
}

paddle/fluid/framework/operator.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,8 +1155,8 @@ Scope* OperatorWithKernel::PrepareData(
11551155
if ((tensor_in->layout() == DataLayout::kMKLDNN) &&
11561156
(var->IsType<LoDTensor>() == true) &&
11571157
(expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) &&
1158-
(paddle::platform::get_cur_paddle_data_layout() ==
1159-
DataLayout::kNHWC)) {
1158+
(paddle::platform::MKLDNNDeviceContext::tls()
1159+
.get_cur_paddle_data_layout() == DataLayout::kNHWC)) {
11601160
// Mixed execution : MKL-DNN and GPU is not supported!
11611161
if (!new_scope) {
11621162
new_scope = &scope.NewScope();

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,14 @@ bool AnalysisPredictor::PrepareExecutor() {
244244
void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
245245
#ifdef PADDLE_WITH_MKLDNN
246246
VLOG(2) << "AnalysisPredictor::Run get_cur_mkldnn_session_id="
247-
<< platform::get_cur_mkldnn_session_id();
247+
<< platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id();
248248
// In cache clearing mode.
249249
if (config_.mkldnn_cache_capacity_ > 0) {
250250
VLOG(2) << "In mkldnn cache clear mode.";
251-
platform::set_cur_mkldnn_session_id(
252-
platform::kMKLDNNSessionID_CacheClearing);
253-
platform::set_cur_input_shape_cache_capacity(
251+
platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
252+
platform::MKLDNNDeviceContextThreadLocals::
253+
kMKLDNNSessionID_CacheClearing);
254+
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(
254255
config_.mkldnn_cache_capacity_);
255256
// Set current_input_shape for caching dynamic shape.
256257
std::stringstream ss;
@@ -260,7 +261,7 @@ void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
260261
}
261262
}
262263
VLOG(2) << "Set input shape=" << ss.str();
263-
platform::set_cur_input_shape_str(ss.str());
264+
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str(ss.str());
264265
}
265266
#endif
266267
}
@@ -277,10 +278,10 @@ void AnalysisPredictor::MkldnnPostReset() {
277278
CHECK_LE(shape_blob_size,
278279
static_cast<size_t>(config_.mkldnn_cache_capacity_));
279280
}
280-
paddle::platform::set_cur_mkldnn_session_id(
281-
platform::kMKLDNNSessionID_Default);
282-
platform::set_cur_input_shape_cache_capacity(0);
283-
platform::set_cur_input_shape_str("");
281+
paddle::platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
282+
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default);
283+
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(0);
284+
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str("");
284285
}
285286
#endif
286287
}

paddle/fluid/operators/controlflow/fetch_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ static void DataCopy(const framework::LoDTensor &src_item,
3434
// Convert to desired Paddle layout, apart from grads of filter
3535
// as params are not subject to paddle's data_format
3636
framework::innerTransDataLayoutFromMKLDNN(
37-
src_item.layout(),
38-
fetch_var_name == framework::GradVarName("Filter")
39-
? framework::DataLayout::kNCHW
40-
: paddle::platform::get_cur_paddle_data_layout(),
37+
src_item.layout(), fetch_var_name == framework::GradVarName("Filter")
38+
? framework::DataLayout::kNCHW
39+
: paddle::platform::MKLDNNDeviceContext::tls()
40+
.get_cur_paddle_data_layout(),
4141
src_item, &out, platform::CPUPlace());
4242
TensorCopySync(out, platform::CPUPlace(), dst_item);
4343
} else {

paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,8 +446,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
446446
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
447447
// are merged/unified, this will disappear
448448
std::string key_tid = "";
449-
if (platform::get_cur_mkldnn_session_id() ==
450-
platform::kMKLDNNSessionID_Default) {
449+
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
450+
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
451451
key_tid = "-t:" + platform::ThreadIDasStr();
452452
}
453453

paddle/fluid/platform/device_context.cc

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -375,36 +375,37 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
375375
p_mutex_.reset(new std::mutex());
376376
}
377377

378-
namespace {
379-
// Current mkldnn session id.
380-
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
381-
// Current data input shape string.
382-
// - For fixed-shape, it's a null string in default.
383-
// - For dynamic-shape, it's user specific.
384-
thread_local std::string cur_input_shape_str = "";
385-
// the cache capacity of different input shapes for MKLDNN.
386-
// Default 1 means fixed input shape, not dynamic shape.
387-
thread_local int cur_input_shape_cache_capacity = 1;
388-
// Recently registered data_format. This is needed to
389-
// know for converting MKL-DNN Tensor to non MKL-DNN
390-
thread_local paddle::framework::DataLayout cur_paddle_data_layout =
391-
paddle::framework::DataLayout::kNCHW;
392-
} // namespace
393-
394-
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
395-
size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
396-
void set_cur_input_shape_str(std::string input_shape_str) {
378+
MKLDNNDeviceContextThreadLocals::Body::Body() {
379+
cur_mkldnn_session_id = kMKLDNNSessionID_Default;
380+
cur_input_shape_str = "";
381+
cur_input_shape_cache_capacity = 1;
382+
cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
383+
}
384+
385+
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
386+
size_t sid) {
387+
cur_mkldnn_session_id = sid;
388+
}
389+
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
390+
return cur_mkldnn_session_id;
391+
}
392+
393+
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
394+
std::string input_shape_str) {
397395
cur_input_shape_str = input_shape_str;
398396
}
399-
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
397+
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
398+
int input_shape_cache_capacity) {
400399
cur_input_shape_cache_capacity = input_shape_cache_capacity;
401400
}
402401

403-
void set_cur_paddle_data_layout(framework::DataLayout dl) {
402+
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
403+
framework::DataLayout dl) {
404404
cur_paddle_data_layout = dl;
405405
}
406406

407-
framework::DataLayout get_cur_paddle_data_layout(void) {
407+
framework::DataLayout
408+
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
408409
return cur_paddle_data_layout;
409410
}
410411

@@ -414,54 +415,55 @@ void MKLDNNDeviceContext::ResetBlobMap() const {
414415
}
415416

416417
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
417-
std::lock_guard<std::mutex> lock(*p_mutex_);
418+
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
418419
BlobMap* pMap = p_blobmap_.get();
419-
auto map_it = pMap->find(cur_mkldnn_session_id);
420+
auto map_it = pMap->find(tls().cur_mkldnn_session_id);
420421
if (map_it == pMap->end()) {
421422
LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
422-
<< cur_mkldnn_session_id;
423+
<< tls().cur_mkldnn_session_id;
423424
}
424425
return map_it->second->size();
425426
}
426427

427428
void MKLDNNDeviceContext::SetBlob(const std::string& name,
428-
std::shared_ptr<void> data) const {
429+
BlobPtr_t<void> data) const {
429430
BlobMap* pMap = p_blobmap_.get();
430-
std::shared_ptr<ShapeBlob> sBlob = nullptr;
431-
std::shared_ptr<KeyBlob> pBlob = nullptr;
431+
BlobPtr_t<ShapeBlob> sBlob = nullptr;
432+
BlobPtr_t<KeyBlob> pBlob = nullptr;
432433

433-
int sid = platform::get_cur_mkldnn_session_id();
434+
int sid = tls().get_cur_mkldnn_session_id();
434435

435-
std::lock_guard<std::mutex> lock(*p_mutex_);
436+
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
436437

437438
// Find ShapeBlob for current mkldnn session id.
438439
auto map_it = pMap->find(sid);
439440

440441
if (map_it == pMap->end()) {
441442
// 1st time to set blob in current thread
442-
sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
443+
sBlob = std::make_shared<ShapeBlob>();
443444
(*pMap)[sid] = sBlob;
444445
VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
445446
} else {
446447
sBlob = map_it->second;
447448
}
448449

449450
// Find KeyBlob for current input shape
450-
auto key_it = sBlob->find(cur_input_shape_str);
451+
auto key_it = sBlob->find(tls().cur_input_shape_str);
451452

452453
if (key_it == sBlob->end()) {
453454
// In cache clearing mode, cur_input_shape_cache_capacity defines
454455
// max pblob capacity
455-
if ((static_cast<size_t>(sid) == kMKLDNNSessionID_CacheClearing) &&
456+
if ((static_cast<size_t>(sid) ==
457+
MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
456458
sBlob->size() &&
457459
(sBlob->size() >=
458-
static_cast<size_t>(cur_input_shape_cache_capacity))) {
460+
static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
459461
VLOG(2) << "sid=" << sid
460462
<< ", remove all blobs of shape: " << sBlob->begin()->first;
461463
sBlob->erase(sBlob->begin()->first);
462464
}
463-
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
464-
(*sBlob)[cur_input_shape_str] = pBlob;
465+
pBlob = std::make_shared<KeyBlob>();
466+
(*sBlob)[tls().cur_input_shape_str] = pBlob;
465467
} else {
466468
pBlob = key_it->second;
467469
}
@@ -478,15 +480,15 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
478480
return;
479481
}
480482

481-
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
483+
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
482484
const std::string& name) const {
483485
BlobMap* pMap = p_blobmap_.get();
484-
std::shared_ptr<ShapeBlob> sBlob = nullptr;
485-
std::shared_ptr<KeyBlob> pBlob = nullptr;
486+
BlobPtr_t<ShapeBlob> sBlob = nullptr;
487+
BlobPtr_t<KeyBlob> pBlob = nullptr;
486488

487-
int sid = platform::get_cur_mkldnn_session_id();
489+
int sid = tls().get_cur_mkldnn_session_id();
488490

489-
std::lock_guard<std::mutex> lock(*p_mutex_);
491+
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
490492

491493
// Find ShapeBlob for current mkldnn session id firstly
492494
auto map_it = pMap->find(sid);
@@ -497,9 +499,9 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
497499
sBlob = map_it->second;
498500

499501
// Find KeyBlob for current input shape secondly
500-
auto sBlob_it = sBlob->find(cur_input_shape_str);
502+
auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
501503
if (sBlob_it == sBlob->end()) {
502-
VLOG(2) << "GetBlob: sid=" << cur_input_shape_str
504+
VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str
503505
<< ", miss input_shape_str\n";
504506
return nullptr;
505507
}

paddle/fluid/platform/device_context.h

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -421,30 +421,66 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
421421
#endif
422422

423423
#ifdef PADDLE_WITH_MKLDNN
424-
// Following three maps are used to cache MKLDNN primitives.
425-
// Their relations are:
426-
// - BlobMap = Map<cur_thread_id, ShapeBlob>
427-
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
428-
// - KeyBlob = Map<blob_name, blob>
429-
// Where:
430-
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
431-
using ShapeBlob = std::unordered_map<std::string, std::shared_ptr<KeyBlob>>;
432-
using BlobMap = std::unordered_map<int, std::shared_ptr<ShapeBlob>>;
433-
434-
// default mkldnn session id
435-
constexpr size_t kMKLDNNSessionID_Default = 0;
436-
// mkldnn session id for cache clearing mode
437-
constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
438-
439-
void set_cur_mkldnn_session_id(size_t);
440-
size_t get_cur_mkldnn_session_id(void);
441-
void set_cur_input_shape_str(std::string input_shape_str);
442-
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
443-
void set_cur_paddle_data_layout(framework::DataLayout);
444-
framework::DataLayout get_cur_paddle_data_layout(void);
424+
425+
class MKLDNNDeviceContextThreadLocals {
426+
// default mkldnn session id
427+
428+
typedef MKLDNNDeviceContextThreadLocals self;
429+
struct Body {
430+
size_t cur_mkldnn_session_id;
431+
// Current data input shape string.
432+
// - For fixed-shape, it's a null string in default.
433+
// - For dynamic-shape, it's user specific.
434+
std::string cur_input_shape_str;
435+
// the cache capacity of different input shapes for MKLDNN.
436+
// Default 1 means fixed input shape, not dynamic shape.
437+
int cur_input_shape_cache_capacity;
438+
// Recently registered data_format. This is needed to
439+
// know for converting MKL-DNN Tensor to non MKL-DNN
440+
paddle::framework::DataLayout cur_paddle_data_layout;
441+
442+
Body();
443+
void set_cur_mkldnn_session_id(size_t sid);
444+
size_t get_cur_mkldnn_session_id(void);
445+
void set_cur_input_shape_str(std::string input_shape_str);
446+
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
447+
void set_cur_paddle_data_layout(framework::DataLayout dl);
448+
framework::DataLayout get_cur_paddle_data_layout(void);
449+
};
450+
MKLDNNDeviceContextThreadLocals() = default;
451+
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
452+
delete;
453+
454+
public:
455+
// default mkldnn session id
456+
static constexpr size_t kMKLDNNSessionID_Default = 0;
457+
// mkldnn session id for cache clearing mode
458+
static constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
459+
static Body& fetch() {
460+
thread_local Body b;
461+
return b;
462+
}
463+
};
445464

446465
class MKLDNNDeviceContext : public CPUDeviceContext {
447466
public:
467+
template <class T>
468+
using BlobPtr_t = std::shared_ptr<T>;
469+
template <class P1, class P2>
470+
using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
471+
template <class T>
472+
using umap_key_string_t = umap_value_smart_t<std::string, T>;
473+
474+
// Following three maps are used to cache MKLDNN primitives.
475+
// Their relations are:
476+
// - BlobMap = Map<cur_thread_id, ShapeBlob>
477+
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
478+
// - KeyBlob = Map<blob_name, blob>
479+
480+
using KeyBlob = umap_key_string_t<void>;
481+
using ShapeBlob = umap_key_string_t<KeyBlob>;
482+
using BlobMap = umap_value_smart_t<int, ShapeBlob>;
483+
448484
explicit MKLDNNDeviceContext(CPUPlace place);
449485

450486
/* \brief Get the active engine */
@@ -462,6 +498,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
462498
// Find a saved blob. Return nullptr if not found
463499
std::shared_ptr<void> GetBlob(const std::string& name) const;
464500

501+
static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
502+
return MKLDNNDeviceContextThreadLocals::fetch();
503+
}
504+
465505
private:
466506
mkldnn::engine engine_;
467507
std::shared_ptr<BlobMap> p_blobmap_;

paddle/fluid/platform/mkldnn_reuse.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class MKLDNNHandlerT {
4242
key_common_(base_key),
4343
fwd_pd_(nullptr),
4444
bwd_pd_(nullptr) {
45-
if (platform::get_cur_mkldnn_session_id() !=
46-
platform::kMKLDNNSessionID_Default) {
45+
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
46+
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
4747
key_ = key_common_;
4848
} else {
4949
key_ = key_common_ + "-t:" + ThreadIDasStr();
@@ -177,8 +177,8 @@ class MKLDNNHandler {
177177
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
178178
const std::string& base_key)
179179
: dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
180-
if (platform::get_cur_mkldnn_session_id() !=
181-
platform::kMKLDNNSessionID_Default) {
180+
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
181+
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
182182
key_ = key_common_;
183183
} else {
184184
key_ = key_common_ + "-t:" + ThreadIDasStr();

0 commit comments

Comments
 (0)