@@ -83,8 +83,8 @@ static void CheckTensorNANOrInf(const std::string& name,
83
83
if (tensor.memory_size () == 0 ) {
84
84
return ;
85
85
}
86
- if (tensor.type ().hash_code () != typeid (float ).hash_code () &&
87
- tensor.type ().hash_code () != typeid (double ).hash_code ()) {
86
+ if (tensor.type ().hash_code () != typeid (float ).hash_code () && // NOLINT
87
+ tensor.type ().hash_code () != typeid (double ).hash_code ()) { // NOLINT
88
88
return ;
89
89
}
90
90
PADDLE_ENFORCE (!framework::TensorContainsInf (tensor),
@@ -145,12 +145,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
145
145
// Return true if the block has feed operators and holder of matching info.
146
146
static bool has_feed_operators (
147
147
const BlockDesc& block,
148
- std::map<std::string, const LoDTensor*>& feed_targets,
148
+ const std::map<std::string, const LoDTensor*>& feed_targets,
149
149
const std::string& feed_holder_name) {
150
150
size_t feed_count = 0 ;
151
151
for (auto * op : block.AllOps ()) {
152
152
if (op->Type () == kFeedOpType ) {
153
153
feed_count++;
154
+ // The input variable's name of feed_op should be feed_holder_name.
154
155
PADDLE_ENFORCE_EQ (op->Input (" X" )[0 ], feed_holder_name,
155
156
" Input to feed op should be '%s'" , feed_holder_name);
156
157
std::string feed_target_name = op->Output (" Out" )[0 ];
@@ -166,13 +167,15 @@ static bool has_feed_operators(
166
167
feed_count, feed_targets.size (),
167
168
" The number of feed operators should match 'feed_targets'" );
168
169
169
- // When feed operator are present, so should be feed_holder
170
- auto var = block.FindVar (feed_holder_name);
171
- PADDLE_ENFORCE_NOT_NULL (var, " Block should already have a '%s' variable" ,
172
- feed_holder_name);
173
- PADDLE_ENFORCE_EQ (var->GetType (), proto::VarType::FEED_MINIBATCH,
174
- " '%s' variable should be 'FEED_MINIBATCH' type" ,
175
- feed_holder_name);
170
+ if (!feed_holder_name.empty ()) {
171
+ // When feed operator are present, so should be feed_holder.
172
+ auto var = block.FindVar (feed_holder_name);
173
+ PADDLE_ENFORCE_NOT_NULL (var, " Block should already have a '%s' variable" ,
174
+ feed_holder_name);
175
+ PADDLE_ENFORCE_EQ (var->GetType (), proto::VarType::FEED_MINIBATCH,
176
+ " '%s' variable should be 'FEED_MINIBATCH' type" ,
177
+ feed_holder_name);
178
+ }
176
179
}
177
180
178
181
return feed_count > 0 ;
@@ -185,12 +188,14 @@ static bool has_feed_operators(
185
188
// and fetch_holder_name. Raise exception when any mismatch is found.
186
189
// Return true if the block has fetch operators and holder of matching info.
187
190
static bool has_fetch_operators (
188
- const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
191
+ const BlockDesc& block,
192
+ const std::map<std::string, LoDTensor*>& fetch_targets,
189
193
const std::string& fetch_holder_name) {
190
194
size_t fetch_count = 0 ;
191
195
for (auto * op : block.AllOps ()) {
192
196
if (op->Type () == kFetchOpType ) {
193
197
fetch_count++;
198
+ // The output variable's name of fetch_op should be fetch_holder_name.
194
199
PADDLE_ENFORCE_EQ (op->Output (" Out" )[0 ], fetch_holder_name,
195
200
" Output of fetch op should be '%s'" , fetch_holder_name);
196
201
std::string fetch_target_name = op->Input (" X" )[0 ];
@@ -206,13 +211,15 @@ static bool has_fetch_operators(
206
211
fetch_count, fetch_targets.size (),
207
212
" The number of fetch operators should match 'fetch_targets'" );
208
213
209
- // When fetch operator are present, so should be fetch_holder
210
- auto var = block.FindVar (fetch_holder_name);
211
- PADDLE_ENFORCE_NOT_NULL (var, " Block should already have a '%s' variable" ,
212
- fetch_holder_name);
213
- PADDLE_ENFORCE_EQ (var->GetType (), proto::VarType::FETCH_LIST,
214
- " '%s' variable should be 'FETCH_LIST' type" ,
215
- fetch_holder_name);
214
+ if (!fetch_holder_name.empty ()) {
215
+ // When fetch operator are present, so should be fetch_holder.
216
+ auto var = block.FindVar (fetch_holder_name);
217
+ PADDLE_ENFORCE_NOT_NULL (var, " Block should already have a '%s' variable" ,
218
+ fetch_holder_name);
219
+ PADDLE_ENFORCE_EQ (var->GetType (), proto::VarType::FETCH_LIST,
220
+ " '%s' variable should be 'FETCH_LIST' type" ,
221
+ fetch_holder_name);
222
+ }
216
223
}
217
224
218
225
return fetch_count > 0 ;
@@ -259,16 +266,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
259
266
}
260
267
}
261
268
262
- // map the data of feed_targets to feed_holder
263
- for (auto * op : global_block->AllOps ()) {
264
- if (op->Type () == kFeedOpType ) {
265
- std::string feed_target_name = op->Output (" Out" )[0 ];
266
- int idx = boost::get<int >(op->GetAttr (" col" ));
267
- SetFeedVariable (scope, *feed_targets[feed_target_name], feed_holder_name,
268
- idx);
269
- }
270
- }
271
-
272
269
if (!has_fetch_ops) {
273
270
// create fetch_holder variable
274
271
auto * fetch_holder = global_block->Var (fetch_holder_name);
@@ -292,17 +289,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
292
289
}
293
290
}
294
291
295
- Run (*copy_program, scope, 0 , create_vars, create_vars);
296
-
297
- // obtain the data of fetch_targets from fetch_holder
298
- for (auto * op : global_block->AllOps ()) {
299
- if (op->Type () == kFetchOpType ) {
300
- std::string fetch_target_name = op->Input (" X" )[0 ];
301
- int idx = boost::get<int >(op->GetAttr (" col" ));
302
- *fetch_targets[fetch_target_name] =
303
- GetFetchVariable (*scope, fetch_holder_name, idx);
304
- }
305
- }
292
+ auto ctx = Prepare (*copy_program, 0 );
293
+ RunPreparedContext (ctx.get (), scope, feed_targets, fetch_targets, create_vars,
294
+ feed_holder_name, fetch_holder_name);
306
295
}
307
296
308
297
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare (
@@ -370,5 +359,42 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
370
359
}
371
360
}
372
361
362
+ void Executor::RunPreparedContext (
363
+ ExecutorPrepareContext* ctx, Scope* scope,
364
+ std::map<std::string, const LoDTensor*>& feed_targets,
365
+ std::map<std::string, LoDTensor*>& fetch_targets, bool create_vars,
366
+ const std::string& feed_holder_name, const std::string& fetch_holder_name) {
367
+ auto & global_block = ctx->prog_ .Block (ctx->block_id_ );
368
+
369
+ PADDLE_ENFORCE (
370
+ has_feed_operators (global_block, feed_targets, feed_holder_name),
371
+ " Program in ExecutorPrepareContext should has feed_ops." );
372
+ PADDLE_ENFORCE (
373
+ has_fetch_operators (global_block, fetch_targets, fetch_holder_name),
374
+ " Program in the prepared context should has fetch_ops." );
375
+
376
+ // map the data of feed_targets to feed_holder
377
+ for (auto * op : global_block.AllOps ()) {
378
+ if (op->Type () == kFeedOpType ) {
379
+ std::string feed_target_name = op->Output (" Out" )[0 ];
380
+ int idx = boost::get<int >(op->GetAttr (" col" ));
381
+ SetFeedVariable (scope, *feed_targets[feed_target_name], feed_holder_name,
382
+ idx);
383
+ }
384
+ }
385
+
386
+ RunPreparedContext (ctx, scope, create_vars, create_vars);
387
+
388
+ // obtain the data of fetch_targets from fetch_holder
389
+ for (auto * op : global_block.AllOps ()) {
390
+ if (op->Type () == kFetchOpType ) {
391
+ std::string fetch_target_name = op->Input (" X" )[0 ];
392
+ int idx = boost::get<int >(op->GetAttr (" col" ));
393
+ *fetch_targets[fetch_target_name] =
394
+ GetFetchVariable (*scope, fetch_holder_name, idx);
395
+ }
396
+ }
397
+ }
398
+
373
399
} // namespace framework
374
400
} // namespace paddle
0 commit comments