Skip to content

Commit fbd3604

Browse files
committed
Split Executor.Run to Executor.Prepare and Executor.RunPreparedContext for inference.
1 parent 172c887 commit fbd3604

File tree

4 files changed

+85
-40
lines changed

4 files changed

+85
-40
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,15 @@ static bool has_feed_operators(
129129
feed_count, feed_targets.size(),
130130
"The number of feed operators should match 'feed_targets'");
131131

132-
// When feed operator are present, so should be feed_holder
133-
auto var = block.FindVar(feed_holder_name);
134-
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
135-
feed_holder_name);
136-
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
137-
"'%s' variable should be 'FEED_MINIBATCH' type",
138-
feed_holder_name);
132+
if (!feed_holder_name.empty()) {
133+
// When feed operator are present, so should be feed_holder
134+
auto var = block.FindVar(feed_holder_name);
135+
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
136+
feed_holder_name);
137+
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
138+
"'%s' variable should be 'FEED_MINIBATCH' type",
139+
feed_holder_name);
140+
}
139141
}
140142

141143
return feed_count > 0;
@@ -169,13 +171,15 @@ static bool has_fetch_operators(
169171
fetch_count, fetch_targets.size(),
170172
"The number of fetch operators should match 'fetch_targets'");
171173

172-
// When fetch operator are present, so should be fetch_holder
173-
auto var = block.FindVar(fetch_holder_name);
174-
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
175-
fetch_holder_name);
176-
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
177-
"'%s' variable should be 'FETCH_LIST' type",
178-
fetch_holder_name);
174+
if (!fetch_holder_name.empty()) {
175+
// When fetch operator are present, so should be fetch_holder
176+
auto var = block.FindVar(fetch_holder_name);
177+
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
178+
fetch_holder_name);
179+
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
180+
"'%s' variable should be 'FETCH_LIST' type",
181+
fetch_holder_name);
182+
}
179183
}
180184

181185
return fetch_count > 0;
@@ -222,16 +226,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
222226
}
223227
}
224228

225-
// map the data of feed_targets to feed_holder
226-
for (auto* op : global_block->AllOps()) {
227-
if (op->Type() == kFeedOpType) {
228-
std::string feed_target_name = op->Output("Out")[0];
229-
int idx = boost::get<int>(op->GetAttr("col"));
230-
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
231-
idx);
232-
}
233-
}
234-
235229
if (!has_fetch_ops) {
236230
// create fetch_holder variable
237231
auto* fetch_holder = global_block->Var(fetch_holder_name);
@@ -255,17 +249,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
255249
}
256250
}
257251

258-
Run(*copy_program, scope, 0, create_vars, create_vars);
259-
260-
// obtain the data of fetch_targets from fetch_holder
261-
for (auto* op : global_block->AllOps()) {
262-
if (op->Type() == kFetchOpType) {
263-
std::string fetch_target_name = op->Input("X")[0];
264-
int idx = boost::get<int>(op->GetAttr("col"));
265-
*fetch_targets[fetch_target_name] =
266-
GetFetchVariable(*scope, fetch_holder_name, idx);
267-
}
268-
}
252+
auto ctx = Prepare(*copy_program, 0);
253+
RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
254+
feed_holder_name, fetch_holder_name, create_vars);
269255
}
270256

271257
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
@@ -343,5 +329,43 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
343329
}
344330
}
345331

332+
void Executor::RunPreparedContext(
333+
ExecutorPrepareContext* ctx, Scope* scope,
334+
std::map<std::string, const LoDTensor*>& feed_targets,
335+
std::map<std::string, LoDTensor*>& fetch_targets,
336+
const std::string& feed_holder_name, const std::string& fetch_holder_name,
337+
bool create_vars) {
338+
auto& global_block = ctx->prog_.Block(ctx->block_id_);
339+
340+
// map the data of feed_targets to feed_holder
341+
for (auto* op : global_block.AllOps()) {
342+
if (op->Type() == kFeedOpType) {
343+
std::string feed_target_name = op->Output("Out")[0];
344+
PADDLE_ENFORCE(feed_targets.find(feed_target_name) != feed_targets.end(),
345+
"Variable %s is not feeded.");
346+
347+
int idx = boost::get<int>(op->GetAttr("col"));
348+
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
349+
idx);
350+
}
351+
}
352+
353+
RunPreparedContext(ctx, scope, create_vars, create_vars);
354+
355+
// obtain the data of fetch_targets from fetch_holder
356+
for (auto* op : global_block.AllOps()) {
357+
if (op->Type() == kFetchOpType) {
358+
std::string fetch_target_name = op->Input("X")[0];
359+
PADDLE_ENFORCE(
360+
fetch_targets.find(fetch_target_name) != fetch_targets.end(),
361+
"Variable %s is not fetched.");
362+
363+
int idx = boost::get<int>(op->GetAttr("col"));
364+
*fetch_targets[fetch_target_name] =
365+
GetFetchVariable(*scope, fetch_holder_name, idx);
366+
}
367+
}
368+
}
369+
346370
} // namespace framework
347371
} // namespace paddle

paddle/fluid/framework/executor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ class Executor {
6565
bool create_local_scope = true,
6666
bool create_vars = true);
6767

68+
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
69+
std::map<std::string, const LoDTensor*>& feed_targets,
70+
std::map<std::string, LoDTensor*>& fetch_targets,
71+
const std::string& feed_holder_name = "feed",
72+
const std::string& fetch_holder_name = "fetch",
73+
bool create_vars = true);
74+
6875
private:
6976
const platform::Place place_;
7077
};

paddle/fluid/inference/tests/book/test_inference_image_classification.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ TEST(inference, image_classification) {
4848

4949
// Run inference on CPU
5050
LOG(INFO) << "--- CPU Runs: ---";
51-
TestInference<paddle::platform::CPUPlace>(
51+
TestInference<paddle::platform::CPUPlace, true>(
5252
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
5353
LOG(INFO) << output1.dims();
5454

@@ -59,7 +59,7 @@ TEST(inference, image_classification) {
5959

6060
// Run inference on CUDA GPU
6161
LOG(INFO) << "--- GPU Runs: ---";
62-
TestInference<paddle::platform::CUDAPlace>(
62+
TestInference<paddle::platform::CUDAPlace, true>(
6363
dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
6464
LOG(INFO) << output2.dims();
6565

paddle/fluid/inference/tests/test_helper.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ void CheckError(paddle::framework::LoDTensor& output1,
8888
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
8989
}
9090

91-
template <typename Place>
91+
template <typename Place, bool PrepareContext = false>
9292
void TestInference(const std::string& dirname,
9393
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
9494
std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
@@ -170,7 +170,14 @@ void TestInference(const std::string& dirname,
170170
// 6. Run the inference program
171171
{
172172
// Ignore the profiling results of the first run
173-
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
173+
std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
174+
if (PrepareContext) {
175+
ctx = executor.Prepare(*inference_program, 0);
176+
executor.RunPreparedContext(
177+
ctx.get(), scope, feed_targets, fetch_targets);
178+
} else {
179+
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
180+
}
174181

175182
// Enable the profiler
176183
paddle::platform::EnableProfiler(state);
@@ -181,7 +188,14 @@ void TestInference(const std::string& dirname,
181188
"run_inference",
182189
paddle::platform::DeviceContextPool::Instance().Get(place));
183190

184-
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
191+
if (PrepareContext) {
192+
// Note: if you changed the inference_program, you need to call
193+
// executor.Prepare() again to get a new ExecutorPrepareContext.
194+
executor.RunPreparedContext(
195+
ctx.get(), scope, feed_targets, fetch_targets);
196+
} else {
197+
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
198+
}
185199
}
186200

187201
// Disable the profiler and print the timing information

0 commit comments

Comments
 (0)