@@ -14,6 +14,10 @@ limitations under the License. */
14
14
15
15
#include " paddle/fluid/framework/executor.h"
16
16
#include < deque>
17
+ #include < memory>
18
+ #include < unordered_map>
19
+ #include < unordered_set>
20
+ #include < utility>
17
21
18
22
#include " paddle/fluid/framework/feed_fetch_method.h"
19
23
#include " paddle/fluid/framework/lod_rank_table.h"
@@ -74,11 +78,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
74
78
75
79
ExecutorPrepareContext::ExecutorPrepareContext (
76
80
const framework::ProgramDesc& prog, size_t block_id,
77
- const std::vector<std::string>& skip_ref_cnt_vars )
78
- : prog_(prog), block_id_(block_id) {
79
- if (GetEagerDeletionThreshold () >= 0 ) {
80
- global_ref_cnts_ = GetNonPersistableReferenceCounts (prog. Block (block_id),
81
- skip_ref_cnt_vars );
81
+ const std::vector<std::string>& keep_vars, bool force_disable_gc )
82
+ : prog_(prog), block_id_(block_id), force_disable_gc_(force_disable_gc) {
83
+ if (GetEagerDeletionThreshold () >= 0 && !force_disable_gc_ ) {
84
+ global_ref_cnts_ =
85
+ GetNonPersistableReferenceCounts (prog. Block (block_id), keep_vars );
82
86
}
83
87
}
84
88
@@ -183,13 +187,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
183
187
}
184
188
185
189
void Executor::Run (const ProgramDesc& pdesc, Scope* scope, int block_id,
186
- bool create_local_scope, bool create_vars) {
190
+ bool create_local_scope, bool create_vars,
191
+ const std::vector<std::string>& skip_ref_cnt_vars,
192
+ bool force_disable_gc) {
187
193
platform::RecordBlock b (block_id);
188
194
if (FLAGS_use_mkldnn) EnableMKLDNN (pdesc);
189
195
#ifdef PADDLE_WITH_NGRAPH
190
196
if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph (pdesc);
191
197
#endif
192
- auto ctx = Prepare (pdesc, block_id);
198
+ auto ctx = Prepare (pdesc, block_id, skip_ref_cnt_vars, force_disable_gc );
193
199
RunPreparedContext (ctx.get (), scope, create_local_scope, create_vars);
194
200
}
195
201
@@ -356,9 +362,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
356
362
357
363
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare (
358
364
const ProgramDesc& program, int block_id,
359
- const std::vector<std::string>& skip_ref_cnt_vars) {
360
- std::unique_ptr<ExecutorPrepareContext> ctx (
361
- new ExecutorPrepareContext ( program, block_id, skip_ref_cnt_vars));
365
+ const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc ) {
366
+ std::unique_ptr<ExecutorPrepareContext> ctx (new ExecutorPrepareContext (
367
+ program, block_id, skip_ref_cnt_vars, force_disable_gc ));
362
368
PADDLE_ENFORCE_LT (static_cast <size_t >(block_id), program.Size ());
363
369
auto & block = program.Block (block_id);
364
370
for (auto & op_desc : block.AllOps ()) {
@@ -369,7 +375,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
369
375
370
376
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare (
371
377
const ProgramDesc& program, const std::vector<int >& block_ids,
372
- const std::vector<std::vector<std::string>>& skip_ref_cnt_vars) {
378
+ const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
379
+ bool force_disable_gc) {
373
380
PADDLE_ENFORCE (
374
381
skip_ref_cnt_vars.empty () || skip_ref_cnt_vars.size () == block_ids.size (),
375
382
" skip_ref_cnt_vars should be either empty or equals to block number %d" ,
@@ -379,9 +386,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
379
386
for (auto & bid : block_ids) {
380
387
ExecutorPrepareContext* ctx;
381
388
if (skip_ref_cnt_vars.empty ()) {
382
- ctx = new ExecutorPrepareContext (program, bid);
389
+ ctx = new ExecutorPrepareContext (program, bid, std::vector<std::string>(),
390
+ force_disable_gc);
383
391
} else {
384
- ctx = new ExecutorPrepareContext (program, bid, skip_ref_cnt_vars[idx]);
392
+ ctx = new ExecutorPrepareContext (program, bid, skip_ref_cnt_vars[idx],
393
+ force_disable_gc);
385
394
}
386
395
PADDLE_ENFORCE_LT (static_cast <size_t >(bid), program.Size ());
387
396
auto & block = program.Block (bid);
@@ -408,8 +417,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
408
417
409
418
int64_t max_memory_size = GetEagerDeletionThreshold ();
410
419
std::unique_ptr<GarbageCollector> gc;
411
- // skip while_op and while_grad_op temporarily
412
- if (max_memory_size >= 0 && !keep_kids) {
420
+ // FIXME(zjl): recurrent_op is rather complex, so for now we
421
+ // forcibly disable gc in recurrent_op
422
+ if (!ctx->force_disable_gc_ && max_memory_size >= 0 && !keep_kids) {
413
423
ctx->ResetReferenceCount ();
414
424
#ifdef PADDLE_WITH_CUDA
415
425
if (platform::is_gpu_place (place_)) {
0 commit comments