Skip to content

Commit 2c552d4

Browse files
authored
Merge pull request #9630 from Xreki/core_inference_prepare
Split Executor.Run into Executor.Prepare and Executor.RunPreparedContext for inference
2 parents 5a4d932 + 449bdde commit 2c552d4

File tree

6 files changed

+100
-50
lines changed

6 files changed

+100
-50
lines changed

cmake/cblas.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ if(NOT CMAKE_CROSSCOMPILING)
7878
/usr/lib/reference/
7979
)
8080
else()
81-
# Diable the finding of reference cblas under host's system path
81+
# Disable the finding of reference cblas under host's system path
8282
set(REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/include)
8383
set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib)
8484
endif()

paddle/fluid/framework/executor.cc

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ static void CheckTensorNANOrInf(const std::string& name,
8383
if (tensor.memory_size() == 0) {
8484
return;
8585
}
86-
if (tensor.type().hash_code() != typeid(float).hash_code() &&
87-
tensor.type().hash_code() != typeid(double).hash_code()) {
86+
if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT
87+
tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT
8888
return;
8989
}
9090
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
@@ -145,12 +145,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
145145
// Return true if the block has feed operators and holder of matching info.
146146
static bool has_feed_operators(
147147
const BlockDesc& block,
148-
std::map<std::string, const LoDTensor*>& feed_targets,
148+
const std::map<std::string, const LoDTensor*>& feed_targets,
149149
const std::string& feed_holder_name) {
150150
size_t feed_count = 0;
151151
for (auto* op : block.AllOps()) {
152152
if (op->Type() == kFeedOpType) {
153153
feed_count++;
154+
// The input variable's name of feed_op should be feed_holder_name.
154155
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
155156
"Input to feed op should be '%s'", feed_holder_name);
156157
std::string feed_target_name = op->Output("Out")[0];
@@ -166,13 +167,15 @@ static bool has_feed_operators(
166167
feed_count, feed_targets.size(),
167168
"The number of feed operators should match 'feed_targets'");
168169

169-
// When feed operator are present, so should be feed_holder
170-
auto var = block.FindVar(feed_holder_name);
171-
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
172-
feed_holder_name);
173-
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
174-
"'%s' variable should be 'FEED_MINIBATCH' type",
175-
feed_holder_name);
170+
if (!feed_holder_name.empty()) {
171+
// When feed operator are present, so should be feed_holder.
172+
auto var = block.FindVar(feed_holder_name);
173+
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
174+
feed_holder_name);
175+
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
176+
"'%s' variable should be 'FEED_MINIBATCH' type",
177+
feed_holder_name);
178+
}
176179
}
177180

178181
return feed_count > 0;
@@ -185,12 +188,14 @@ static bool has_feed_operators(
185188
// and fetch_holder_name. Raise exception when any mismatch is found.
186189
// Return true if the block has fetch operators and holder of matching info.
187190
static bool has_fetch_operators(
188-
const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
191+
const BlockDesc& block,
192+
const std::map<std::string, LoDTensor*>& fetch_targets,
189193
const std::string& fetch_holder_name) {
190194
size_t fetch_count = 0;
191195
for (auto* op : block.AllOps()) {
192196
if (op->Type() == kFetchOpType) {
193197
fetch_count++;
198+
// The output variable's name of fetch_op should be fetch_holder_name.
194199
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
195200
"Output of fetch op should be '%s'", fetch_holder_name);
196201
std::string fetch_target_name = op->Input("X")[0];
@@ -206,13 +211,15 @@ static bool has_fetch_operators(
206211
fetch_count, fetch_targets.size(),
207212
"The number of fetch operators should match 'fetch_targets'");
208213

209-
// When fetch operator are present, so should be fetch_holder
210-
auto var = block.FindVar(fetch_holder_name);
211-
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
212-
fetch_holder_name);
213-
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
214-
"'%s' variable should be 'FETCH_LIST' type",
215-
fetch_holder_name);
214+
if (!fetch_holder_name.empty()) {
215+
// When fetch operator are present, so should be fetch_holder.
216+
auto var = block.FindVar(fetch_holder_name);
217+
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
218+
fetch_holder_name);
219+
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
220+
"'%s' variable should be 'FETCH_LIST' type",
221+
fetch_holder_name);
222+
}
216223
}
217224

218225
return fetch_count > 0;
@@ -259,16 +266,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
259266
}
260267
}
261268

262-
// map the data of feed_targets to feed_holder
263-
for (auto* op : global_block->AllOps()) {
264-
if (op->Type() == kFeedOpType) {
265-
std::string feed_target_name = op->Output("Out")[0];
266-
int idx = boost::get<int>(op->GetAttr("col"));
267-
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
268-
idx);
269-
}
270-
}
271-
272269
if (!has_fetch_ops) {
273270
// create fetch_holder variable
274271
auto* fetch_holder = global_block->Var(fetch_holder_name);
@@ -292,17 +289,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
292289
}
293290
}
294291

295-
Run(*copy_program, scope, 0, create_vars, create_vars);
296-
297-
// obtain the data of fetch_targets from fetch_holder
298-
for (auto* op : global_block->AllOps()) {
299-
if (op->Type() == kFetchOpType) {
300-
std::string fetch_target_name = op->Input("X")[0];
301-
int idx = boost::get<int>(op->GetAttr("col"));
302-
*fetch_targets[fetch_target_name] =
303-
GetFetchVariable(*scope, fetch_holder_name, idx);
304-
}
305-
}
292+
auto ctx = Prepare(*copy_program, 0);
293+
RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, create_vars,
294+
feed_holder_name, fetch_holder_name);
306295
}
307296

308297
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
@@ -370,5 +359,42 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
370359
}
371360
}
372361

362+
void Executor::RunPreparedContext(
363+
ExecutorPrepareContext* ctx, Scope* scope,
364+
std::map<std::string, const LoDTensor*>& feed_targets,
365+
std::map<std::string, LoDTensor*>& fetch_targets, bool create_vars,
366+
const std::string& feed_holder_name, const std::string& fetch_holder_name) {
367+
auto& global_block = ctx->prog_.Block(ctx->block_id_);
368+
369+
PADDLE_ENFORCE(
370+
has_feed_operators(global_block, feed_targets, feed_holder_name),
371+
"Program in ExecutorPrepareContext should has feed_ops.");
372+
PADDLE_ENFORCE(
373+
has_fetch_operators(global_block, fetch_targets, fetch_holder_name),
374+
"Program in the prepared context should has fetch_ops.");
375+
376+
// map the data of feed_targets to feed_holder
377+
for (auto* op : global_block.AllOps()) {
378+
if (op->Type() == kFeedOpType) {
379+
std::string feed_target_name = op->Output("Out")[0];
380+
int idx = boost::get<int>(op->GetAttr("col"));
381+
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
382+
idx);
383+
}
384+
}
385+
386+
RunPreparedContext(ctx, scope, create_vars, create_vars);
387+
388+
// obtain the data of fetch_targets from fetch_holder
389+
for (auto* op : global_block.AllOps()) {
390+
if (op->Type() == kFetchOpType) {
391+
std::string fetch_target_name = op->Input("X")[0];
392+
int idx = boost::get<int>(op->GetAttr("col"));
393+
*fetch_targets[fetch_target_name] =
394+
GetFetchVariable(*scope, fetch_holder_name, idx);
395+
}
396+
}
397+
}
398+
373399
} // namespace framework
374400
} // namespace paddle

paddle/fluid/framework/executor.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <map>
18+
#include <string>
19+
#include <vector>
1720
#include "paddle/fluid/framework/op_info.h"
1821
#include "paddle/fluid/framework/program_desc.h"
1922
#include "paddle/fluid/framework/scope.h"
@@ -70,6 +73,13 @@ class Executor {
7073
bool create_local_scope = true,
7174
bool create_vars = true);
7275

76+
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
77+
std::map<std::string, const LoDTensor*>& feed_targets,
78+
std::map<std::string, LoDTensor*>& fetch_targets,
79+
bool create_vars = true,
80+
const std::string& feed_holder_name = "feed",
81+
const std::string& fetch_holder_name = "fetch");
82+
7383
private:
7484
const platform::Place place_;
7585
};

paddle/fluid/inference/io.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ limitations under the License. */
2323
namespace paddle {
2424
namespace inference {
2525

26-
// Temporarilly add this function for exposing framework::InitDevices() when
26+
// Temporarily add this function for exposing framework::InitDevices() when
2727
// linking the inference shared library.
2828
void Init(bool init_p2p) { framework::InitDevices(init_p2p); }
2929

paddle/fluid/inference/tests/book/test_inference_image_classification.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ TEST(inference, image_classification) {
4646

4747
// Run inference on CPU
4848
LOG(INFO) << "--- CPU Runs: ---";
49-
TestInference<paddle::platform::CPUPlace, false>(dirname, cpu_feeds,
50-
cpu_fetchs1, FLAGS_repeat);
49+
TestInference<paddle::platform::CPUPlace, false, true>(
50+
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
5151
LOG(INFO) << output1.dims();
5252

5353
#ifdef PADDLE_WITH_CUDA
@@ -57,8 +57,8 @@ TEST(inference, image_classification) {
5757

5858
// Run inference on CUDA GPU
5959
LOG(INFO) << "--- GPU Runs: ---";
60-
TestInference<paddle::platform::CUDAPlace, false>(dirname, cpu_feeds,
61-
cpu_fetchs2, FLAGS_repeat);
60+
TestInference<paddle::platform::CUDAPlace, false, true>(
61+
dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
6262
LOG(INFO) << output2.dims();
6363

6464
CheckError<float>(output1, output2);

paddle/fluid/inference/tests/test_helper.h

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
8989
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
9090
}
9191

92-
template <typename Place, bool CreateVars = true>
92+
template <typename Place, bool CreateVars = true, bool PrepareContext = false>
9393
void TestInference(const std::string& dirname,
9494
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
9595
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
@@ -175,8 +175,15 @@ void TestInference(const std::string& dirname,
175175
}
176176

177177
// Ignore the profiling results of the first run
178-
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
179-
CreateVars);
178+
std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
179+
if (PrepareContext) {
180+
ctx = executor.Prepare(*inference_program, 0);
181+
executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
182+
CreateVars);
183+
} else {
184+
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
185+
CreateVars);
186+
}
180187

181188
// Enable the profiler
182189
paddle::platform::EnableProfiler(state);
@@ -187,8 +194,15 @@ void TestInference(const std::string& dirname,
187194
"run_inference",
188195
paddle::platform::DeviceContextPool::Instance().Get(place));
189196

190-
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
191-
CreateVars);
197+
if (PrepareContext) {
198+
// Note: if you change the inference_program, you need to call
199+
// executor.Prepare() again to get a new ExecutorPrepareContext.
200+
executor.RunPreparedContext(ctx.get(), scope, feed_targets,
201+
fetch_targets, CreateVars);
202+
} else {
203+
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
204+
CreateVars);
205+
}
192206
}
193207

194208
// Disable the profiler and print the timing information

0 commit comments

Comments
 (0)