Skip to content

Commit 114eb17

Browse files
committed
fix executor bug
1 parent 612e1a3 commit 114eb17

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
337337

338338
std::unique_ptr<GarbageCollector<Tensor>> gc;
339339
if (max_memory_size >= 0) {
340+
ctx->ResetReferenceCount();
340341
#ifdef PADDLE_WITH_CUDA
341342
if (platform::is_gpu_place(place_)) {
342343
gc.reset(new DefaultStreamGarbageCollector<Tensor>(
@@ -357,11 +358,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
357358
std::vector<std::string> erase_vars;
358359
for (auto& input : op->Inputs()) {
359360
for (auto& input_name : input.second) {
360-
auto it = ctx->ref_cnts_.find(input_name);
361-
if (it == ctx->ref_cnts_.end()) continue;
361+
auto it = ctx->cur_ref_cnts_.find(input_name);
362+
if (it == ctx->cur_ref_cnts_.end()) continue;
362363
if (it->second == 1) { // should delete it
363364
erase_vars.emplace_back(input_name);
364-
ctx->ref_cnts_.erase(input_name);
365+
ctx->cur_ref_cnts_.erase(input_name);
365366
} else {
366367
--(it->second);
367368
}
@@ -370,11 +371,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
370371

371372
for (auto& output : op->Outputs()) {
372373
for (auto& output_name : output.second) {
373-
auto it = ctx->ref_cnts_.find(output_name);
374-
if (it == ctx->ref_cnts_.end()) continue;
374+
auto it = ctx->cur_ref_cnts_.find(output_name);
375+
if (it == ctx->cur_ref_cnts_.end()) continue;
375376
if (it->second == 1) {
376377
erase_vars.emplace_back(output_name);
377-
ctx->ref_cnts_.erase(output_name);
378+
ctx->cur_ref_cnts_.erase(output_name);
378379
} else {
379380
--(it->second);
380381
}

paddle/fluid/framework/executor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,14 @@ struct ExecutorPrepareContext {
7272
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
7373
~ExecutorPrepareContext();
7474

75+
void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; }
76+
7577
const framework::ProgramDesc& prog_;
7678
size_t block_id_;
7779
std::vector<std::unique_ptr<OperatorBase>> ops_;
7880

7981
std::unordered_map<std::string, int> ref_cnts_;
82+
std::unordered_map<std::string, int> cur_ref_cnts_;
8083
};
8184

8285
class Executor {

0 commit comments

Comments
 (0)