@@ -113,10 +113,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
113
113
// and feed_holder_name. Raise exception when any mismatch is found.
114
114
// Return true if the block has feed operators and holder of matching info.
115
115
static bool has_feed_operators (
116
- BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets,
116
+ const BlockDesc& block,
117
+ std::map<std::string, const LoDTensor*>& feed_targets,
117
118
const std::string& feed_holder_name) {
118
119
size_t feed_count = 0 ;
119
- for (auto * op : block-> AllOps ()) {
120
+ for (auto * op : block. AllOps ()) {
120
121
if (op->Type () == kFeedOpType ) {
121
122
feed_count++;
122
123
PADDLE_ENFORCE_EQ (op->Input (" X" )[0 ], feed_holder_name,
@@ -135,7 +136,7 @@ static bool has_feed_operators(
135
136
" The number of feed operators should match 'feed_targets'" );
136
137
137
138
// When feed operator are present, so should be feed_holder
138
- auto var = block-> FindVar (feed_holder_name);
139
+ auto var = block. FindVar (feed_holder_name);
139
140
PADDLE_ENFORCE_NOT_NULL (var, " Block should already have a '%s' variable" ,
140
141
feed_holder_name);
141
142
PADDLE_ENFORCE_EQ (var->GetType (), proto::VarType::FEED_MINIBATCH,
@@ -153,10 +154,10 @@ static bool has_feed_operators(
153
154
// and fetch_holder_name. Raise exception when any mismatch is found.
154
155
// Return true if the block has fetch operators and holder of matching info.
155
156
static bool has_fetch_operators (
156
- BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets,
157
+ const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
157
158
const std::string& fetch_holder_name) {
158
159
size_t fetch_count = 0 ;
159
- for (auto * op : block-> AllOps ()) {
160
+ for (auto * op : block. AllOps ()) {
160
161
if (op->Type () == kFetchOpType ) {
161
162
fetch_count++;
162
163
PADDLE_ENFORCE_EQ (op->Output (" Out" )[0 ], fetch_holder_name,
@@ -175,7 +176,7 @@ static bool has_fetch_operators(
175
176
" The number of fetch operators should match 'fetch_targets'" );
176
177
177
178
// When fetch operator are present, so should be fetch_holder
178
- auto var = block-> FindVar (fetch_holder_name);
179
+ auto var = block. FindVar (fetch_holder_name);
179
180
PADDLE_ENFORCE_NOT_NULL (var, " Block should already have a '%s' variable" ,
180
181
fetch_holder_name);
181
182
PADDLE_ENFORCE_EQ (var->GetType (), proto::VarType::FETCH_LIST,
@@ -192,10 +193,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
192
193
const std::string& feed_holder_name,
193
194
const std::string& fetch_holder_name) {
194
195
platform::RecordBlock b (kProgramId );
195
- auto * copy_program = new ProgramDesc (program);
196
+ bool has_feed_ops =
197
+ has_feed_operators (program.Block (0 ), feed_targets, feed_holder_name);
198
+ bool has_fetch_ops =
199
+ has_fetch_operators (program.Block (0 ), fetch_targets, fetch_holder_name);
200
+
201
+ ProgramDesc* copy_program = const_cast <ProgramDesc*>(&program);
202
+ if (!has_feed_ops || !has_fetch_ops) {
203
+ copy_program = std::unique_ptr<ProgramDesc>(new ProgramDesc (program)).get ();
204
+ }
205
+
196
206
auto * global_block = copy_program->MutableBlock (0 );
197
207
198
- if (!has_feed_operators (global_block, feed_targets, feed_holder_name) ) {
208
+ if (!has_feed_ops ) {
199
209
// create feed_holder variable
200
210
auto * feed_holder = global_block->Var (feed_holder_name);
201
211
feed_holder->SetType (proto::VarType::FEED_MINIBATCH);
@@ -228,7 +238,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
228
238
}
229
239
}
230
240
231
- if (!has_fetch_operators (global_block, fetch_targets, fetch_holder_name) ) {
241
+ if (!has_fetch_ops ) {
232
242
// create fetch_holder variable
233
243
auto * fetch_holder = global_block->Var (fetch_holder_name);
234
244
fetch_holder->SetType (proto::VarType::FETCH_LIST);
@@ -262,8 +272,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
262
272
GetFetchVariable (*scope, fetch_holder_name, idx);
263
273
}
264
274
}
265
-
266
- delete copy_program;
267
275
}
268
276
269
277
ExecutorPrepareContext* Executor::Prepare (const ProgramDesc& program,
@@ -313,9 +321,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
313
321
} // if (create_vars)
314
322
315
323
for (auto & op : ctx->ops_ ) {
316
- VLOG (4 ) << place_ << " " << op->DebugStringEx (local_scope);
317
- op->Run (*local_scope, place_);
318
324
VLOG (3 ) << place_ << " " << op->DebugStringEx (local_scope);
325
+ op->Run (*local_scope, place_);
319
326
320
327
if (FLAGS_benchmark) {
321
328
VLOG (2 ) << " Memory used after operator " + op->Type () + " running: "
0 commit comments