Commit af91d41

Merge pull request #13852 from sneaxiy/feature/eager_delete_tensor
Fix bug of eager deletion to support if_else_op
2 parents 93606c2 + d3ed070

7 files changed: 91 additions & 80 deletions
paddle/fluid/framework/details/op_handle_base.h

Lines changed: 2 additions & 1 deletion
@@ -64,7 +64,8 @@ class OpHandleBase {
   virtual bool IsMultiDeviceTransfer() { return false; }
 
   const platform::DeviceContext *DeviceContext(platform::Place place) {
-    return dev_ctxes_[place];
+    auto it = dev_ctxes_.find(place);
+    return it != dev_ctxes_.end() ? it->second : nullptr;
   }
 
   void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
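
Note: the lookup above switches from std::map::operator[] to find() because operator[] default-inserts a value for a missing key. A minimal standalone sketch of the difference, using a hypothetical map of device names rather than Paddle's actual types:

#include <cassert>
#include <map>

int main() {
  std::map<int, const char*> ctxes;
  ctxes[0] = "cpu_context";

  // operator[] on a missing key inserts {1, nullptr} as a side effect.
  const char* bad = ctxes[1];
  assert(bad == nullptr);
  assert(ctxes.size() == 2);  // the map grew unexpectedly

  // find() leaves the map untouched and lets the caller handle the miss.
  auto it = ctxes.find(2);
  const char* good = (it != ctxes.end()) ? it->second : nullptr;
  assert(good == nullptr);
  assert(ctxes.size() == 2);  // unchanged
  return 0;
}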

paddle/fluid/framework/executor.cc

Lines changed: 43 additions & 41 deletions
@@ -46,6 +46,41 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
   VLOG(5) << "destroy ExecutorPrepareContext";
 }
 
+template <typename RefCntMap>
+static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
+                                GarbageCollector<Tensor>* gc,
+                                RefCntMap* ref_cnts) {
+  std::unordered_set<Tensor*> erase_tensors;
+
+  auto handler = [&](const VariableNameMap& name_map) {
+    for (auto& name_pair : name_map) {
+      for (auto& name : name_pair.second) {
+        auto it = ref_cnts->find(name);
+        if (it == ref_cnts->end()) continue;
+        if ((it->second)-- == 1) {
+          auto* var = scope.FindVar(name);
+          if (var != nullptr) {
+            VLOG(10) << "Erase tensor \'" << name << "\'";
+            if (var->IsType<LoDTensor>()) {
+              erase_tensors.insert(var->GetMutable<LoDTensor>());
+            } else if (var->IsType<SelectedRows>()) {
+              erase_tensors.insert(
+                  var->GetMutable<SelectedRows>()->mutable_value());
+            }
+          }
+        }
+      }
+    }
+  };
+
+  handler(op->Inputs());
+  handler(op->Outputs());
+
+  if (!erase_tensors.empty()) {
+    gc->Add(erase_tensors);
+  }
+}
+
 Executor::Executor(const platform::Place& place) : place_(place) {}
 
 void Executor::Close() {
@@ -331,9 +366,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }
 
   int64_t max_memory_size = GetEagerDeletionThreshold();
-
   std::unique_ptr<GarbageCollector<Tensor>> gc;
-  if (max_memory_size >= 0) {
+  // WhileOp would set keep_kids to false
+  // WhileGradOp would need the scopes created in WhileOp
+  // Perhaps, we should not perform eager deletion in WhileOp
+  // The scopes and variables created by WhileOp would be deleted
+  // in WhileGradOp.
+  if (max_memory_size >= 0 && !keep_kids) {
     ctx->ResetReferenceCount();
 #ifdef PADDLE_WITH_CUDA
     if (platform::is_gpu_place(place_)) {
@@ -352,45 +391,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
     op->Run(*local_scope, place_);
 
     if (gc != nullptr) {
-      std::vector<std::string> erase_vars;
-      for (auto& input : op->Inputs()) {
-        for (auto& input_name : input.second) {
-          auto it = ctx->cur_ref_cnts_.find(input_name);
-          if (it == ctx->cur_ref_cnts_.end()) continue;
-          if (it->second == 1) {  // should delete it
-            erase_vars.emplace_back(input_name);
-            ctx->cur_ref_cnts_.erase(input_name);
-          } else {
-            --(it->second);
-          }
-        }
-      }
-
-      for (auto& output : op->Outputs()) {
-        for (auto& output_name : output.second) {
-          auto it = ctx->cur_ref_cnts_.find(output_name);
-          if (it == ctx->cur_ref_cnts_.end()) continue;
-          if (it->second == 1) {
-            erase_vars.emplace_back(output_name);
-            ctx->cur_ref_cnts_.erase(output_name);
-          } else {
-            --(it->second);
-          }
-        }
-      }
-
-      if (!erase_vars.empty()) {
-        std::vector<framework::LoDTensor*> erase_tensors;
-        for (auto& name : erase_vars) {
-          auto* var = local_scope->FindVar(name);
-          if (var == nullptr) continue;
-          if (var->IsType<framework::LoDTensor>()) {
-            auto* tensor = var->GetMutable<framework::LoDTensor>();
-            erase_tensors.push_back(tensor);
-          }
-        }
-        if (!erase_tensors.empty()) gc->Add(erase_tensors);
-      }
+      DeleteUnusedTensors(*local_scope, op.get(), gc.get(),
+                          &(ctx->cur_ref_cnts_));
     }
 
     if (FLAGS_benchmark) {
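
Note: the refactor above replaces two near-identical input/output loops with a single DeleteUnusedTensors helper that also covers SelectedRows. A minimal, self-contained sketch of the underlying pattern; FakeGC and plain string names are hypothetical stand-ins for Paddle's GarbageCollector and tensors:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct FakeGC {
  void Add(const std::unordered_set<std::string>& victims) {
    for (auto& v : victims) std::cout << "free " << v << "\n";
  }
};

// After an op runs, decrement each input/output name's count; names that
// drop to zero are batched into one garbage-collector Add() call.
void DecrementAndCollect(const std::vector<std::string>& names,
                         std::unordered_map<std::string, int>* ref_cnts,
                         FakeGC* gc) {
  std::unordered_set<std::string> erase;  // set dedups repeated names
  for (auto& name : names) {
    auto it = ref_cnts->find(name);
    if (it == ref_cnts->end()) continue;          // untracked (e.g. persistable)
    if ((it->second)-- == 1) erase.insert(name);  // last use just happened
  }
  if (!erase.empty()) gc->Add(erase);
}

int main() {
  std::unordered_map<std::string, int> ref_cnts{{"x", 2}, {"y", 1}};
  FakeGC gc;
  DecrementAndCollect({"x", "y"}, &ref_cnts, &gc);  // frees y only
  DecrementAndCollect({"x"}, &ref_cnts, &gc);       // frees x
  return 0;
}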

paddle/fluid/framework/executor.h

Lines changed: 19 additions & 25 deletions
@@ -32,38 +32,32 @@ template <typename T>
 std::unordered_map<std::string, T> GetNonPersistableReferenceCount(
     const ProgramDesc& prog, size_t block_id) {
   auto& block = prog.Block(block_id);
-  std::unordered_set<std::string> ignored_vars;
   std::unordered_map<std::string, T> ref_cnts;
 
-  for (auto var_desc : block.AllVars()) {
-    auto type = var_desc->Proto()->type().type();
-    if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) {
-      ignored_vars.insert(var_desc->Name());  // ignore persistable vars
-    }
-  }
-
-  for (auto op_desc : block.AllOps()) {
-    for (auto& input : op_desc->Inputs()) {
-      for (auto& input_name : input.second) {
-        if (!ignored_vars.count(input_name)) {
-          if (ref_cnts.count(input_name))
-            ++ref_cnts[input_name];
-          else
-            ref_cnts[input_name] = 1;
+  auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) {
+    for (auto& name_pair : name_map) {
+      for (auto& name : name_pair.second) {
+        auto* var_desc = block.FindVar(name);
+        if (var_desc == nullptr || var_desc->Persistable()) continue;
+        auto type = var_desc->Proto()->type().type();
+        if (type != proto::VarType::LOD_TENSOR &&
+            type != proto::VarType::SELECTED_ROWS) {
+          continue;
         }
-      }
-    }
 
-    for (auto& output : op_desc->Outputs()) {
-      for (auto output_name : output.second) {
-        if (!ignored_vars.count(output_name)) {
-          if (ref_cnts.count(output_name))
-            ++ref_cnts[output_name];
-          else
-            ref_cnts[output_name] = 1;
+        auto it = ref_cnts.find(name);
+        if (it != ref_cnts.end()) {
+          ++it->second;
+        } else {
+          ref_cnts[name] = 1;
         }
       }
     }
+  };
+
+  for (auto op_desc : block.AllOps()) {
+    update_ref_cnts(op_desc, op_desc->Inputs());
+    update_ref_cnts(op_desc, op_desc->Outputs());
   }
   return ref_cnts;
 }
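
Note: the rewritten GetNonPersistableReferenceCount above derives each variable's count directly from the ops that mention it. A minimal sketch of that counting step; FakeOp is a hypothetical stand-in for OpDesc, and the persistable/type filtering of the real code is omitted:

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct FakeOp {
  std::vector<std::string> inputs, outputs;
};

// Every mention of a name in any op's inputs or outputs bumps its count, so
// the count equals the number of uses the runtime must observe before the
// variable can be freed.
std::unordered_map<std::string, int> BuildRefCounts(
    const std::vector<FakeOp>& ops) {
  std::unordered_map<std::string, int> ref_cnts;
  auto update = [&](const std::vector<std::string>& names) {
    for (auto& name : names) ++ref_cnts[name];  // missing keys start at 0
  };
  for (auto& op : ops) {
    update(op.inputs);
    update(op.outputs);
  }
  return ref_cnts;
}

int main() {
  std::vector<FakeOp> ops = {{{"x"}, {"y"}}, {{"y"}, {"z"}}};
  auto cnts = BuildRefCounts(ops);
  assert(cnts["y"] == 2);  // produced once, consumed once
  assert(cnts["x"] == 1 && cnts["z"] == 1);
  return 0;
}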

paddle/fluid/framework/parallel_executor.cc

Lines changed: 4 additions & 0 deletions
@@ -307,6 +307,10 @@ ParallelExecutor::~ParallelExecutor() {
       }
     }
   }
+
+  // member_ must be destructed before gcs_ since the destructor of
+  // ReferenceCountOpHandle use raw pointers of gcs_ inside.
+  member_.reset();
 }
 
 }  // namespace framework
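
Note: the explicit member_.reset() works around C++'s default member destruction order, which is the reverse of declaration order. A minimal sketch of the hazard with hypothetical Owner/User/Resource classes:

#include <iostream>
#include <memory>

struct Resource {
  ~Resource() { std::cout << "Resource destroyed\n"; }
};

struct User {
  Resource* res;  // raw, non-owning pointer
  ~User() { std::cout << "User destroyed (res still valid)\n"; }
};

class Owner {
 public:
  Owner() : user_(new User{&resource_}) {}
  ~Owner() {
    // Without this reset(), resource_ (declared later) would be destroyed
    // first, so ~User would run while its raw pointer dangles.
    user_.reset();
  }

 private:
  std::unique_ptr<User> user_;  // declared first => destroyed last by default
  Resource resource_;
};

int main() { Owner o; }  // prints "User destroyed ..." before "Resource destroyed"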

paddle/fluid/framework/parallel_executor.h

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ class ParallelExecutor {
  private:
   void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
 
-  ParallelExecutorPrivate *member_;
+  std::unique_ptr<ParallelExecutorPrivate> member_;
 
 #ifdef PADDLE_WITH_CUDA
   // ref_cnts_ is only initialized when ParallelExecutor constructs, and then

paddle/fluid/framework/scope.cc

Lines changed: 17 additions & 12 deletions
@@ -49,18 +49,18 @@ int64_t GetEagerDeletionThreshold() {
 Scope::~Scope() { DropKids(); }
 
 Scope& Scope::NewScope() const {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   kids_.push_back(new Scope(this));
   return *kids_.back();
 }
 
 Variable* Scope::Var(const std::string& name) {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   return VarInternal(name);
 }
 
 Variable* Scope::Var(std::string* name) {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   auto new_name = string::Sprintf("%p.%d", this, vars_.size());
   if (name != nullptr) {
     *name = new_name;
@@ -69,29 +69,34 @@ Variable* Scope::Var(std::string* name) {
 }
 
 Variable* Scope::FindVar(const std::string& name) const {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   return FindVarInternal(name);
 }
 
+Variable* Scope::FindLocalVar(const std::string& name) const {
+  std::lock_guard<std::mutex> lock(mutex_);
+  return FindVarLocally(name);
+}
+
 const Scope* Scope::FindScope(const Variable* var) const {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   return FindScopeInternal(var);
 }
 
 void Scope::DropKids() {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   for (Scope* s : kids_) delete s;
   kids_.clear();
 }
 
 bool Scope::HasKid(const Scope* scope) const {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
   return it != this->kids_.end();
 }
 
 std::vector<std::string> Scope::LocalVarNames() const {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   std::vector<std::string> known_vars;
   known_vars.reserve(this->vars_.size());
   for (auto& p : vars_) {
@@ -101,7 +106,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
 }
 
 void Scope::DeleteScope(Scope* scope) const {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
   PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
   this->kids_.erase(it);
@@ -114,7 +119,7 @@ void Scope::DeleteScope(Scope* scope) const {
 }
 
 void Scope::EraseVars(const std::vector<std::string>& var_names) {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   std::set<std::string> var_set(var_names.begin(), var_names.end());
   for (auto it = vars_.begin(); it != vars_.end();) {
     if (var_set.find(it->first) != var_set.end()) {
@@ -127,12 +132,12 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
 
 void Scope::Rename(const std::string& origin_name,
                    const std::string& new_name) const {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   RenameInternal(origin_name, new_name);
 }
 
 std::string Scope::Rename(const std::string& origin_name) const {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(mutex_);
   auto new_name = string::Sprintf("%p.%d", this, vars_.size());
   RenameInternal(origin_name, new_name);
   return new_name;
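
Note: every std::unique_lock in Scope above becomes a std::lock_guard. These methods only lock on entry and unlock on scope exit, so the lighter guard suffices; unique_lock is only needed for deferred locking, early unlock(), or condition variables. A minimal sketch of the idiom with a hypothetical thread-safe counter:

#include <mutex>

class Counter {
 public:
  void Increment() {
    std::lock_guard<std::mutex> lock(mutex_);  // unlocks automatically
    ++value_;
  }

  int Get() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return value_;
  }

 private:
  mutable std::mutex mutex_;  // mutable: lockable from const methods
  int value_ = 0;
};

int main() {
  Counter c;
  c.Increment();
  return c.Get() == 1 ? 0 : 1;
}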

paddle/fluid/framework/scope.h

Lines changed: 5 additions & 0 deletions
@@ -63,6 +63,11 @@ class Scope {
   /// Caller doesn't own the returned Variable.
   Variable* FindVar(const std::string& name) const;
 
+  /// Find a variable in the current scope.
+  /// Return nullptr if cannot find.
+  /// Caller doesn't own the returned Variable.
+  Variable* FindLocalVar(const std::string& name) const;
+
   const Scope* parent() const { return parent_; }
 
   /// Find the scope or an ancestor scope that contains the given variable.
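
Note: the new FindLocalVar complements the recursive FindVar. A minimal sketch of the two lookup behaviors with a hypothetical MiniScope holding int variables:

#include <cassert>
#include <string>
#include <unordered_map>

class MiniScope {
 public:
  explicit MiniScope(const MiniScope* parent = nullptr) : parent_(parent) {}

  void Set(const std::string& name, int value) { vars_[name] = value; }

  // Local-only lookup: nullptr if the name is absent from this scope.
  const int* FindLocalVar(const std::string& name) const {
    auto it = vars_.find(name);
    return it == vars_.end() ? nullptr : &it->second;
  }

  // Recursive lookup: falls back to ancestor scopes on a local miss.
  const int* FindVar(const std::string& name) const {
    if (const int* v = FindLocalVar(name)) return v;
    return parent_ ? parent_->FindVar(name) : nullptr;
  }

 private:
  std::unordered_map<std::string, int> vars_;
  const MiniScope* parent_;
};

int main() {
  MiniScope root;
  root.Set("x", 1);
  MiniScope kid(&root);
  assert(kid.FindVar("x") != nullptr);       // found in the parent
  assert(kid.FindLocalVar("x") == nullptr);  // not local to the kid
  return 0;
}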
