Skip to content

Commit 9b3e79a

Browse files
committed
cherry-pick mem op to release/1.3
test=release/1.3
1 parent e0bb8cc commit 9b3e79a

File tree

13 files changed

+578
-74
lines changed

13 files changed

+578
-74
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/framework/executor.h"
1616
#include <deque>
17+
#include <memory>
18+
#include <unordered_map>
19+
#include <unordered_set>
20+
#include <utility>
1721

1822
#include "paddle/fluid/framework/feed_fetch_method.h"
1923
#include "paddle/fluid/framework/lod_rank_table.h"
@@ -74,11 +78,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
7478

7579
ExecutorPrepareContext::ExecutorPrepareContext(
7680
const framework::ProgramDesc& prog, size_t block_id,
77-
const std::vector<std::string>& skip_ref_cnt_vars)
78-
: prog_(prog), block_id_(block_id) {
79-
if (GetEagerDeletionThreshold() >= 0) {
80-
global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
81-
skip_ref_cnt_vars);
81+
const std::vector<std::string>& keep_vars, bool force_disable_gc)
82+
: prog_(prog), block_id_(block_id), force_disable_gc_(force_disable_gc) {
83+
if (GetEagerDeletionThreshold() >= 0 && !force_disable_gc_) {
84+
global_ref_cnts_ =
85+
GetNonPersistableReferenceCounts(prog.Block(block_id), keep_vars);
8286
}
8387
}
8488

@@ -183,13 +187,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
183187
}
184188

185189
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
186-
bool create_local_scope, bool create_vars) {
190+
bool create_local_scope, bool create_vars,
191+
const std::vector<std::string>& skip_ref_cnt_vars,
192+
bool force_disable_gc) {
187193
platform::RecordBlock b(block_id);
188194
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
189195
#ifdef PADDLE_WITH_NGRAPH
190196
if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc);
191197
#endif
192-
auto ctx = Prepare(pdesc, block_id);
198+
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
193199
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
194200
}
195201

@@ -356,9 +362,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
356362

357363
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
358364
const ProgramDesc& program, int block_id,
359-
const std::vector<std::string>& skip_ref_cnt_vars) {
360-
std::unique_ptr<ExecutorPrepareContext> ctx(
361-
new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars));
365+
const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
366+
std::unique_ptr<ExecutorPrepareContext> ctx(new ExecutorPrepareContext(
367+
program, block_id, skip_ref_cnt_vars, force_disable_gc));
362368
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
363369
auto& block = program.Block(block_id);
364370
for (auto& op_desc : block.AllOps()) {
@@ -369,7 +375,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
369375

370376
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
371377
const ProgramDesc& program, const std::vector<int>& block_ids,
372-
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars) {
378+
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
379+
bool force_disable_gc) {
373380
PADDLE_ENFORCE(
374381
skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
375382
"skip_ref_cnt_vars should be either empty or equals to block number %d",
@@ -379,9 +386,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
379386
for (auto& bid : block_ids) {
380387
ExecutorPrepareContext* ctx;
381388
if (skip_ref_cnt_vars.empty()) {
382-
ctx = new ExecutorPrepareContext(program, bid);
389+
ctx = new ExecutorPrepareContext(program, bid, std::vector<std::string>(),
390+
force_disable_gc);
383391
} else {
384-
ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]);
392+
ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx],
393+
force_disable_gc);
385394
}
386395
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
387396
auto& block = program.Block(bid);
@@ -408,8 +417,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
408417

409418
int64_t max_memory_size = GetEagerDeletionThreshold();
410419
std::unique_ptr<GarbageCollector> gc;
411-
// skip while_op and while_grad_op temporarily
412-
if (max_memory_size >= 0 && !keep_kids) {
420+
// FIXME(zjl): recurrent_op is rather complex, we would
421+
// forcibly disable gc in recurrent_op
422+
if (!ctx->force_disable_gc_ && max_memory_size >= 0 && !keep_kids) {
413423
ctx->ResetReferenceCount();
414424
#ifdef PADDLE_WITH_CUDA
415425
if (platform::is_gpu_place(place_)) {

paddle/fluid/framework/executor.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include <map>
18+
#include <memory>
1819
#include <string>
20+
#include <unordered_map>
1921
#include <vector>
2022
#include "paddle/fluid/framework/garbage_collector.h"
2123
#include "paddle/fluid/framework/op_info.h"
@@ -30,14 +32,16 @@ namespace framework {
3032
struct ExecutorPrepareContext {
3133
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id,
3234
const std::vector<std::string>& skip_ref_cnt_vars =
33-
std::vector<std::string>());
35+
std::vector<std::string>(),
36+
bool force_disable_gc = false);
3437

3538
~ExecutorPrepareContext();
3639

3740
void ResetReferenceCount() { runtime_ref_cnts_ = global_ref_cnts_; }
3841

3942
const framework::ProgramDesc& prog_;
4043
size_t block_id_;
44+
bool force_disable_gc_;
4145
std::vector<std::unique_ptr<OperatorBase>> ops_;
4246

4347
std::unordered_map<std::string, size_t> global_ref_cnts_;
@@ -66,7 +70,10 @@ class Executor {
6670
* Scope
6771
*/
6872
void Run(const ProgramDesc& prog, Scope* scope, int block_id,
69-
bool create_local_scope = true, bool create_vars = true);
73+
bool create_local_scope = true, bool create_vars = true,
74+
const std::vector<std::string>& skip_ref_cnt_vars =
75+
std::vector<std::string>(),
76+
bool force_disable_gc = false);
7077

7178
// This API is very slow.
7279
void Run(const ProgramDesc& program, Scope* scope,
@@ -79,12 +86,14 @@ class Executor {
7986
static std::unique_ptr<ExecutorPrepareContext> Prepare(
8087
const ProgramDesc& program, int block_id,
8188
const std::vector<std::string>& skip_ref_cnt_vars =
82-
std::vector<std::string>());
89+
std::vector<std::string>(),
90+
bool force_disable_gc = false);
8391

8492
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
8593
const ProgramDesc& program, const std::vector<int>& block_ids,
8694
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars =
87-
std::vector<std::vector<std::string>>());
95+
std::vector<std::vector<std::string>>(),
96+
bool force_disable_gc = false);
8897

8998
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
9099

0 commit comments

Comments
 (0)