@@ -52,11 +52,11 @@ typedef struct {
 // The traversal order also affects the lifecycles, so a different sort_kind is
 // used.
 void MemoryOptimizePass::CollectLifeCycle(
-    std::unordered_map<std::string, lifecycle_t>* lifecycles,
+    Graph* graph, std::unordered_map<std::string, lifecycle_t>* lifecycles,
     int sort_kind) const {
-  max_lifecycle_ = 0;
+  int max_lifecycle = 0;
   for (auto* op_node : framework::ir::TopologyVarientSort(
-           *graph_, static_cast<framework::ir::SortKind>(sort_kind))) {
+           *graph, static_cast<framework::ir::SortKind>(sort_kind))) {
     if (!op_node->IsOp()) continue;
     auto reads = op_node->inputs;
     auto writes = op_node->outputs;
@@ -77,20 +77,20 @@ void MemoryOptimizePass::CollectLifeCycle(
         if (node->Var()->Persistable()) continue;
         std::string var = node->Name();
         if (!lifecycles->count(var)) {
-          (*lifecycles)[var] = std::make_pair(max_lifecycle_, max_lifecycle_);
+          (*lifecycles)[var] = std::make_pair(max_lifecycle, max_lifecycle);
         } else {
           (*lifecycles)[var].second =
-              std::max(max_lifecycle_, lifecycles->at(var).second);  // max()
+              std::max(max_lifecycle, lifecycles->at(var).second);  // max()
         }
       }
     }

-    ++max_lifecycle_;
+    ++max_lifecycle;
   }
 }

 void MemoryOptimizePass::CollectVarMemorySize(
-    space_table_t* space_table) const {
+    Graph* graph, space_table_t* space_table) const {
   const int fake_batch_size = 1;

   auto valid_var = [&](framework::ir::Node* node) -> bool {
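The two hunks above keep the traversal counter local instead of using the member max_lifecycle_, and take the graph as an argument. Judging from the make_pair/std::max calls, lifecycle_t is a pair of step indices: each non-persistable variable gets the interval [first use, last use] over the topologically sorted ops. A minimal, self-contained sketch of that bookkeeping; the Op struct and plain std::string names are stand-ins for framework::ir::Node, not the real Paddle API:

```cpp
// Sketch of the lifecycle bookkeeping above, with plain STL types standing in
// for framework::ir::Node; Op and this CollectLifeCycle are illustrative only.
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using lifecycle_t = std::pair<int, int>;  // [first_use_step, last_use_step]

struct Op {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Ops are assumed to arrive already topologically sorted; the step counter
// plays the role of max_lifecycle in the pass.
void CollectLifeCycle(const std::vector<Op>& ops,
                      std::unordered_map<std::string, lifecycle_t>* lifecycles) {
  int max_lifecycle = 0;
  for (const Op& op : ops) {
    std::vector<std::string> used = op.inputs;
    used.insert(used.end(), op.outputs.begin(), op.outputs.end());
    for (const std::string& var : used) {
      if (!lifecycles->count(var)) {
        (*lifecycles)[var] = std::make_pair(max_lifecycle, max_lifecycle);
      } else {
        (*lifecycles)[var].second =
            std::max(max_lifecycle, lifecycles->at(var).second);
      }
    }
    ++max_lifecycle;
  }
}

int main() {
  // op0: x -> tmp, op1: (tmp, x) -> y; so tmp lives over steps [0, 1].
  std::vector<Op> ops = {{{"x"}, {"tmp"}}, {{"tmp", "x"}, {"y"}}};
  std::unordered_map<std::string, lifecycle_t> lifecycles;
  CollectLifeCycle(ops, &lifecycles);
  for (const auto& kv : lifecycles)
    std::cout << kv.first << ": [" << kv.second.first << ", "
              << kv.second.second << "]\n";
}
```

Variables whose intervals never overlap are candidates to share a buffer, which is presumably what MakeSimpleReusePlan exploits later in RunImpl.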
@@ -130,7 +130,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
   // although it's not always the case. so black list is the best compromise
   // between performance and underlying principle.
   std::unordered_set<std::string> black_list;
-  for (auto* node : graph_->Nodes()) {
+  for (auto* node : graph->Nodes()) {
     if (node->IsVar() &&
         node->Var()->GetType() ==
             framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
@@ -141,7 +141,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
   }

   // Collect tensors from graph.
-  for (auto* node : graph_->Nodes()) {
+  for (auto* node : graph->Nodes()) {
     if (node->IsVar() &&
         node->Var()->GetType() ==
             framework::proto::VarType::Type::VarType_Type_LOD_TENSOR &&
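These two hunks only swap the member graph_ for the graph parameter inside CollectVarMemorySize. For orientation, a rough sketch of the kind of entry that could end up in the space table, assuming (as fake_batch_size = 1 above suggests) that an unknown batch dimension is priced at 1; space_table_t here and EstimateBytes are illustrative stand-ins, not the actual Paddle definitions:

```cpp
// Illustrative size estimation only: space_table_t and EstimateBytes are
// stand-ins, not the actual Paddle definitions.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using space_table_t = std::unordered_map<std::string, std::size_t>;

// Price a tensor at batch size 1: any unknown (-1) dimension counts as
// fake_batch_size, mirroring the constant in CollectVarMemorySize.
std::size_t EstimateBytes(const std::vector<int64_t>& shape,
                          std::size_t elem_size, int fake_batch_size = 1) {
  std::size_t elems = 1;
  for (int64_t d : shape) elems *= (d < 0 ? fake_batch_size : d);
  return elems * elem_size;
}

int main() {
  space_table_t space_table;
  // A float LoDTensor of shape [-1, 128] is priced as 1 * 128 * 4 = 512 bytes.
  space_table["fc_out"] = EstimateBytes({-1, 128}, sizeof(float));
  std::cout << "fc_out: " << space_table["fc_out"] << " bytes\n";
}
```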
@@ -304,18 +304,21 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
   // 3. Perform reuse plan: Replace all var's name in the model according to the
   //    mapping table.
   if (!argument->enable_memory_optim()) return;
-  graph_ = argument->main_graph_ptr();
+  // Because the pass is a singleton, the graph cannot be a member
+  // variable; otherwise, errors will be caused under multithreading
+  // conditions.
+  auto graph = argument->main_graph_ptr();

   int sort_kind = 0;
   std::unordered_map<std::string, lifecycle_t> lifecycles;
   space_table_t space_table;
   std::unordered_map<std::string, std::string> node2cluster;
   std::unordered_map<std::string, int> cluster_size;

-  CollectLifeCycle(&lifecycles, sort_kind);
-  CollectVarMemorySize(&space_table);
+  CollectLifeCycle(graph, &lifecycles, sort_kind);
+  CollectVarMemorySize(graph, &space_table);
   MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size);
-  UpdateOpDescsByReuse(graph_, node2cluster, sort_kind);
+  UpdateOpDescsByReuse(graph, node2cluster, sort_kind);
   return;
 }
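The last hunk carries the motivation for the whole change: RunImpl previously stashed the graph in the member graph_, but the pass object is shared across predictors (the new comment calls it a singleton), so concurrent runs would overwrite each other's state. Passing the graph down as a parameter, and making max_lifecycle a local in CollectLifeCycle, keeps all per-run state on the caller's stack. A minimal sketch of the difference under the assumption of two concurrent callers; SingletonPass and FakeGraph are illustrative, not the real Paddle classes:

```cpp
// Sketch of the hazard fixed by this patch. SingletonPass and FakeGraph are
// illustrative stand-ins, not the real Paddle classes.
#include <cassert>
#include <string>
#include <thread>

struct FakeGraph {
  std::string name;
};

struct SingletonPass {
  // Old style: per-run state lives in the shared pass object. If two threads
  // called this concurrently, their writes to graph_ would race and one run
  // could end up reading the other's graph.
  FakeGraph* graph_ = nullptr;
  std::string RunStateful(FakeGraph* g) {
    graph_ = g;
    std::this_thread::yield();  // widens the window for the race
    return graph_->name;
  }

  // New style: per-run state stays on the caller's stack, so the shared pass
  // object is never written during a run and the method is re-entrant.
  std::string RunStateless(FakeGraph* g) const { return g->name; }
};

int main() {
  SingletonPass pass;  // one instance shared by all threads, like a singleton
  FakeGraph a{"graph_a"}, b{"graph_b"};
  std::string ra, rb;
  std::thread ta([&] { ra = pass.RunStateless(&a); });
  std::thread tb([&] { rb = pass.RunStateless(&b); });
  ta.join();
  tb.join();
  assert(ra == "graph_a" && rb == "graph_b");  // holds for the stateless form
  return 0;
}
```

The stateless form is re-entrant because the shared object is never written during a run, which is what this patch achieves for MemoryOptimizePass.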