Skip to content

Commit 5925b82

Browse files
authored
multithread memory optimize error fix (#37894) (#38737)
* multithread_memory_optimize
1 parent aebc5a9 commit 5925b82

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ typedef struct {
5252
// The traversal order also affect the lifecycles, so different sort_kind is
5353
// used.
5454
void MemoryOptimizePass::CollectLifeCycle(
55-
std::unordered_map<std::string, lifecycle_t>* lifecycles,
55+
Graph* graph, std::unordered_map<std::string, lifecycle_t>* lifecycles,
5656
int sort_kind) const {
57-
max_lifecycle_ = 0;
57+
int max_lifecycle = 0;
5858
for (auto* op_node : framework::ir::TopologyVarientSort(
59-
*graph_, static_cast<framework::ir::SortKind>(sort_kind))) {
59+
*graph, static_cast<framework::ir::SortKind>(sort_kind))) {
6060
if (!op_node->IsOp()) continue;
6161
auto reads = op_node->inputs;
6262
auto writes = op_node->outputs;
@@ -77,20 +77,20 @@ void MemoryOptimizePass::CollectLifeCycle(
7777
if (node->Var()->Persistable()) continue;
7878
std::string var = node->Name();
7979
if (!lifecycles->count(var)) {
80-
(*lifecycles)[var] = std::make_pair(max_lifecycle_, max_lifecycle_);
80+
(*lifecycles)[var] = std::make_pair(max_lifecycle, max_lifecycle);
8181
} else {
8282
(*lifecycles)[var].second =
83-
std::max(max_lifecycle_, lifecycles->at(var).second); // max()
83+
std::max(max_lifecycle, lifecycles->at(var).second); // max()
8484
}
8585
}
8686
}
8787

88-
++max_lifecycle_;
88+
++max_lifecycle;
8989
}
9090
}
9191

9292
void MemoryOptimizePass::CollectVarMemorySize(
93-
space_table_t* space_table) const {
93+
Graph* graph, space_table_t* space_table) const {
9494
const int fake_batch_size = 1;
9595

9696
auto valid_var = [&](framework::ir::Node* node) -> bool {
@@ -130,7 +130,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
130130
// although it's not always the case. so black list is the best compromise
131131
// between performance and underlying principle.
132132
std::unordered_set<std::string> black_list;
133-
for (auto* node : graph_->Nodes()) {
133+
for (auto* node : graph->Nodes()) {
134134
if (node->IsVar() &&
135135
node->Var()->GetType() ==
136136
framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
@@ -141,7 +141,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
141141
}
142142

143143
// Collect tensors from graph.
144-
for (auto* node : graph_->Nodes()) {
144+
for (auto* node : graph->Nodes()) {
145145
if (node->IsVar() &&
146146
node->Var()->GetType() ==
147147
framework::proto::VarType::Type::VarType_Type_LOD_TENSOR &&
@@ -304,18 +304,21 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
304304
// 3. Perform reuse plan: Replace all var's name in the model according to the
305305
// mapping table.
306306
if (!argument->enable_memory_optim()) return;
307-
graph_ = argument->main_graph_ptr();
307+
// Because of pass is a singleton, graph can not be member
308+
// variables,otherwise,errors will be caused under multithreading
309+
// conditions.
310+
auto graph = argument->main_graph_ptr();
308311

309312
int sort_kind = 0;
310313
std::unordered_map<std::string, lifecycle_t> lifecycles;
311314
space_table_t space_table;
312315
std::unordered_map<std::string, std::string> node2cluster;
313316
std::unordered_map<std::string, int> cluster_size;
314317

315-
CollectLifeCycle(&lifecycles, sort_kind);
316-
CollectVarMemorySize(&space_table);
318+
CollectLifeCycle(graph, &lifecycles, sort_kind);
319+
CollectVarMemorySize(graph, &space_table);
317320
MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size);
318-
UpdateOpDescsByReuse(graph_, node2cluster, sort_kind);
321+
UpdateOpDescsByReuse(graph, node2cluster, sort_kind);
319322
return;
320323
}
321324

paddle/fluid/inference/analysis/passes/memory_optimize_pass.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,15 @@ class MemoryOptimizePass : public AnalysisPass {
5757

5858
private:
5959
void CollectLifeCycle(
60+
framework::ir::Graph *graph,
6061
std::unordered_map<std::string, lifecycle_t> *lifecycles,
6162
int sort_kind) const;
6263

63-
void CollectVarMemorySize(space_table_t *space_table) const;
64+
void CollectVarMemorySize(framework::ir::Graph *graph,
65+
space_table_t *space_table) const;
6466

6567
public:
6668
std::string repr() const override;
67-
68-
private:
69-
mutable framework::ir::Graph *graph_{nullptr};
70-
mutable int max_lifecycle_{-1};
7169
};
7270

7371
} // namespace analysis

0 commit comments

Comments
 (0)