17
17
namespace paddle {
18
18
namespace framework {
19
19
20
- // Holds all the transfer scope across the process.
21
20
std::unordered_map<size_t , Scope*>& global_transfer_data_cache () {
22
- typedef std::unordered_map<size_t , Scope*> map_t ;
23
- thread_local std::unique_ptr<map_t > x (new map_t );
21
+ thread_local auto * x = new std::unordered_map<size_t , Scope*>;
24
22
return *x;
25
23
}
26
24
27
- // Holds all the transfer scope for this thread.
28
25
std::unordered_set<Scope*>& global_transfer_scope_cache () {
29
- typedef std::unordered_set<Scope*> set_t ;
30
- thread_local std::unique_ptr<set_t > x (new set_t );
26
+ thread_local auto * x = new std::unordered_set<Scope*>;
31
27
return *x;
32
28
}
33
29
34
- // Try to create a transfer scope. If one cached scope has match the
35
- // requirement, just return that one.
36
- // Inputs:
37
- // @type0: the source kernel type.
38
- // @type1: the target kernel type.
39
- // @scope: the execution scope of this op.
40
- // Returns: A scope used to hold the transfer data across the different kernel
41
- // type.
42
30
Scope* TryCreateTransferScope (OpKernelType type0, OpKernelType type1,
43
31
const Scope* scope) {
44
32
Scope* new_scope{nullptr };
@@ -58,5 +46,27 @@ Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1,
58
46
return new_scope;
59
47
}
60
48
49
+ void RemoveKidsFromTransferScopeCache (Scope* scope) {
50
+ auto it = global_transfer_scope_cache ().find (scope);
51
+ if (it != global_transfer_scope_cache ().end ()) {
52
+ global_transfer_scope_cache ().erase (it);
53
+ }
54
+ for (auto * s : scope->kids ()) {
55
+ auto it = global_transfer_scope_cache ().find (s);
56
+ if (it != global_transfer_scope_cache ().end ()) {
57
+ global_transfer_scope_cache ().erase (it);
58
+ }
59
+ }
60
+
61
+ // remove global transfer data cache
62
+ auto & cache = global_transfer_data_cache ();
63
+ for (auto it = cache.begin (); it != cache.end ();) {
64
+ if (it->second == scope)
65
+ it = cache.erase (it);
66
+ else
67
+ it++;
68
+ }
69
+ }
70
+
61
71
} // namespace framework
62
72
} // namespace paddle
0 commit comments