
Commit 823c4f8

Merge pull request #13058 from panyx0718/infer

use fast RunPrepareContext for inference

2 parents: 7cb6fe7 + 5adf118

6 files changed: +91 -52 lines changed

paddle/fluid/framework/executor.h

Lines changed: 2 additions & 0 deletions

@@ -60,6 +60,7 @@ class Executor {
   void Run(const ProgramDesc& prog, Scope* scope, int block_id,
            bool create_local_scope = true, bool create_vars = true);
 
+  // This API is very slow.
   void Run(const ProgramDesc& program, Scope* scope,
            std::map<std::string, const LoDTensor*>* feed_targets,
            std::map<std::string, LoDTensor*>* fetch_targets,
@@ -79,6 +80,7 @@ class Executor {
                           bool create_local_scope = true,
                           bool create_vars = true, bool keep_kids = false);
 
+  // This API is very slow.
   void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                           std::map<std::string, const LoDTensor*>* feed_targets,
                           std::map<std::string, LoDTensor*>* fetch_targets,
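
The overloads marked "// This API is very slow." are the ones taking feed/fetch target maps: every call matches feed and fetch variables by name inside the executor. The fast path this PR moves inference onto prepares the program once and reuses the context across runs. A minimal sketch of that pattern, assuming an already-loaded ProgramDesc and an initialized Scope; FastInferLoop and its single feed/fetch column are illustrative, not part of this commit:

// Sketch of the fast path: prepare once, then run the prepared context
// repeatedly. Illustrative only; error handling and device setup omitted.
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/platform/place.h"

void FastInferLoop(const paddle::framework::ProgramDesc &program,
                   paddle::framework::Scope *scope,
                   const std::vector<paddle::framework::LoDTensor> &batches) {
  paddle::framework::Executor executor(paddle::platform::CPUPlace());
  // Pay the op-creation/preparation cost a single time.
  auto ctx = executor.Prepare(program, /*block_id=*/0);
  executor.CreateVariables(program, scope, /*block_id=*/0);
  for (const auto &input : batches) {
    // Write the input into the "feed" holder at column 0.
    paddle::framework::SetFeedVariable(scope, input, "feed", 0);
    // Fast overload: no per-call name matching, no scope/variable churn.
    executor.RunPreparedContext(ctx.get(), scope,
                                /*create_local_scope=*/false,
                                /*create_vars=*/false);
    auto &output = paddle::framework::GetFetchVariable(*scope, "fetch", 0);
    (void)output;  // consume the result here
  }
}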

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 1 addition & 2 deletions

@@ -80,8 +80,7 @@ class AnalysisPredictor : public NativePaddlePredictor {
                                sub_scope_ ? sub_scope_ : scope_.get(), 0);
 
     // Get the feed_target_names and fetch_target_names
-    feed_target_names_ = inference_program_->GetFeedTargetNames();
-    fetch_target_names_ = inference_program_->GetFetchTargetNames();
+    PrepareFeedFetch();
     return true;
   }
paddle/fluid/inference/api/api_impl.cc

Lines changed: 53 additions & 40 deletions

@@ -21,6 +21,7 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/inference/api/api_impl.h"
 #include "paddle/fluid/platform/profiler.h"
 
@@ -57,6 +58,25 @@ std::string num2str(T a) {
 }
 }  // namespace
 
+void NativePaddlePredictor::PrepareFeedFetch() {
+  for (auto *op : inference_program_->Block(0).AllOps()) {
+    if (op->Type() == "feed") {
+      int idx = boost::get<int>(op->GetAttr("col"));
+      if (feeds_.size() <= idx) {
+        feeds_.resize(idx + 1);
+      }
+      feeds_[idx] = op;
+      feed_names_[op->Output("Out")[0]] = idx;
+    } else if (op->Type() == "fetch") {
+      int idx = boost::get<int>(op->GetAttr("col"));
+      if (fetchs_.size() <= idx) {
+        fetchs_.resize(idx + 1);
+      }
+      fetchs_[idx] = op;
+    }
+  }
+}
+
 bool NativePaddlePredictor::Init(
     std::shared_ptr<framework::Scope> parent_scope) {
   VLOG(3) << "Predictor::init()";
@@ -108,8 +128,7 @@ bool NativePaddlePredictor::Init(
                               sub_scope_ ? sub_scope_ : scope_.get(), 0);
 
   // Get the feed_target_names and fetch_target_names
-  feed_target_names_ = inference_program_->GetFeedTargetNames();
-  fetch_target_names_ = inference_program_->GetFetchTargetNames();
+  PrepareFeedFetch();
   return true;
 }
@@ -130,36 +149,21 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
   Timer timer;
   timer.tic();
   // set feed variable
-  std::map<std::string, const framework::LoDTensor *> feed_targets;
   std::vector<framework::LoDTensor> feeds;
-  if (!SetFeed(inputs, &feeds)) {
+  framework::Scope *scope = sub_scope_ != nullptr ? sub_scope_ : scope_.get();
+  if (!SetFeed(inputs, scope)) {
     LOG(ERROR) << "fail to set feed";
     return false;
   }
-  for (size_t i = 0; i < feed_target_names_.size(); ++i) {
-    if (config_.specify_input_name) {
-      feed_targets[inputs[i].name] = &feeds[i];
-    } else {
-      feed_targets[feed_target_names_[i]] = &feeds[i];
-    }
-  }
-  // get fetch variable
-  std::map<std::string, framework::LoDTensor *> fetch_targets;
-  std::vector<framework::LoDTensor> fetchs;
-  fetchs.resize(fetch_target_names_.size());
-  for (size_t i = 0; i < fetch_target_names_.size(); ++i) {
-    fetch_targets[fetch_target_names_[i]] = &fetchs[i];
-  }
   // Run the inference program
   // if share variables, we need not create variables
   VLOG(4) << "Run prepared context";
-  executor_->RunPreparedContext(
-      ctx_.get(), sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
-      &feed_targets, &fetch_targets,
-      false, /* don't create local scope each time*/
-      false /* don't create variable eatch time */);
+  executor_->RunPreparedContext(ctx_.get(), scope,
+                                false, /* don't create local scope each time*/
+                                false /* don't create variable each time */);
   VLOG(4) << "Finish prepared context";
-  if (!GetFetch(fetchs, output_data)) {
+  // get fetch variable
+  if (!GetFetch(output_data, scope)) {
     LOG(ERROR) << "fail to get fetches";
     return false;
   }
@@ -180,13 +184,13 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
 }
 
 bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
-                                    std::vector<framework::LoDTensor> *feeds) {
+                                    framework::Scope *scope) {
   VLOG(3) << "Predictor::set_feed";
-  if (inputs.size() != feed_target_names_.size()) {
+  if (inputs.size() != feeds_.size()) {
     LOG(ERROR) << "wrong feed input size.";
     return false;
   }
-  for (size_t i = 0; i < feed_target_names_.size(); ++i) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
     framework::LoDTensor input;
     framework::DDim ddim = framework::make_ddim(inputs[i].shape);
     void *input_ptr;
@@ -208,29 +212,38 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
       lod.emplace_back(level);
     }
     input.set_lod(lod);
-
-    feeds->push_back(input);
+    int idx = -1;
+    if (config_.specify_input_name) {
+      idx = feed_names_[inputs[i].name];
+    } else {
+      idx = boost::get<int>(feeds_[i]->GetAttr("col"));
+    }
+    framework::SetFeedVariable(scope, input, "feed", idx);
   }
   return true;
 }
 
-bool NativePaddlePredictor::GetFetch(
-    const std::vector<framework::LoDTensor> &fetchs,
-    std::vector<PaddleTensor> *outputs) {
+bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
+                                     framework::Scope *scope) {
   VLOG(3) << "Predictor::get_fetch";
-  outputs->resize(fetchs.size());
-  for (size_t i = 0; i < fetchs.size(); ++i) {
+  outputs->resize(fetchs_.size());
+  for (size_t i = 0; i < fetchs_.size(); ++i) {
+    int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
+    PADDLE_ENFORCE(idx == i);
+    framework::LoDTensor &output =
+        framework::GetFetchVariable(*scope, "fetch", idx);
     // TODO(panyx0718): Support fetch of other types.
-    if (fetchs[i].type() != typeid(float)) {
+    if (output.type() != typeid(float)) {
       LOG(ERROR) << "only support fetching float now.";
      return false;
    }
+
    std::vector<int> shape;
-    auto dims_i = fetchs[i].dims();
-    auto lod = fetchs[i].lod();
-    const float *output_ptr = fetchs[i].data<float>();
+    auto dims_i = output.dims();
+    auto lod = output.lod();
+    const float *output_ptr = output.data<float>();
     // const int64_t* output_ptr = fetchs[i].data<int64_t>();
-    auto num = fetchs[i].numel();
+    auto num = output.numel();
     std::vector<float> data;
     if (0 == lod.size()) {
       std::copy(output_ptr, output_ptr + num, std::back_inserter(data));
@@ -275,7 +288,7 @@ bool NativePaddlePredictor::GetFetch(
   }
   std::memcpy(buffer.data(), data.data(), buffer.length());
   // copy LoD
-  for (const auto &level : fetchs[i].lod()) {
+  for (const auto &level : output.lod()) {
     outputs->at(i).lod.emplace_back(level);
   }
   outputs->at(i).dtype = PaddleDType::FLOAT32;
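
PrepareFeedFetch, SetFeed, and GetFetch all revolve around the feed/fetch "holder" variables: block 0's feed ops read from a single scope variable named "feed" and its fetch ops write to one named "fetch", each indexed by the op's "col" attribute. Roughly, the holders behave like the sketch below — a simplified model for illustration only; the real implementation is framework::SetFeedVariable/GetFetchVariable in paddle/fluid/framework/feed_fetch_method.cc, and Tensor here is a stand-in:

// Simplified model of the feed/fetch holder mechanism. Each holder is one
// scope variable containing a list of tensors; a feed/fetch op's "col"
// attribute is an index into that list.
#include <cstddef>
#include <vector>

struct Tensor {};  // stand-in for framework::LoDTensor

using FeedFetchList = std::vector<Tensor>;  // stand-in for the holder type

// Like SetFeedVariable: grow the holder if needed, write at column `idx`.
// The program's feed ops later copy holder[idx] into their output variables.
void SetFeedSketch(FeedFetchList *feed_holder, const Tensor &input,
                   size_t idx) {
  if (feed_holder->size() <= idx) feed_holder->resize(idx + 1);
  (*feed_holder)[idx] = input;
}

// Like GetFetchVariable: read what the fetch op at column `idx` wrote.
const Tensor &GetFetchSketch(const FeedFetchList &fetch_holder, size_t idx) {
  return fetch_holder.at(idx);
}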

paddle/fluid/inference/api/api_impl.h

Lines changed: 9 additions & 5 deletions

@@ -15,6 +15,7 @@
 #pragma once
 
 #include <glog/logging.h>
+#include <map>
 #include <memory>
 #include <string>
 #include <vector>
@@ -47,18 +48,21 @@ class NativePaddlePredictor : public PaddlePredictor {
 
  protected:
   bool SetFeed(const std::vector<PaddleTensor> &input_datas,
-               std::vector<framework::LoDTensor> *feeds);
-  bool GetFetch(const std::vector<framework::LoDTensor> &fetchs,
-                std::vector<PaddleTensor> *output_data);
+               framework::Scope *scope);
+  bool GetFetch(std::vector<PaddleTensor> *output_data,
+                framework::Scope *scope);
+
+  void PrepareFeedFetch();
 
   NativeConfig config_;
   platform::Place place_;
   std::unique_ptr<framework::Executor> executor_;
   std::shared_ptr<framework::Scope> scope_;
   std::unique_ptr<framework::ExecutorPrepareContext> ctx_;
   std::unique_ptr<framework::ProgramDesc> inference_program_;
-  std::vector<std::string> feed_target_names_;
-  std::vector<std::string> fetch_target_names_;
+  std::vector<framework::OpDesc *> feeds_;
+  std::map<std::string, size_t> feed_names_;
+  std::vector<framework::OpDesc *> fetchs_;
   // Do not use unique_ptr, use parent scope to delete
   framework::Scope *sub_scope_{nullptr};
 };
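
The header change replaces the name lists (feed_target_names_, fetch_target_names_) with op descriptors indexed by holder column (feeds_, fetchs_) plus a name-to-column map (feed_names_), so no per-run string matching remains. The public predictor interface is unchanged; a hedged caller-side sketch under assumed names — the model path, input name "words", and shape are placeholders, not from this commit:

// Hypothetical caller-side usage; paths, names, and shapes are placeholders.
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::NativeConfig config;
  config.model_dir = "/path/to/model";  // placeholder
  config.use_gpu = false;
  config.specify_input_name = true;     // inputs matched via feed_names_
  auto predictor = paddle::CreatePaddlePredictor<paddle::NativeConfig>(config);

  paddle::PaddleTensor input;
  input.name = "words";                 // must name an actual feed target
  input.shape = {1, 128};
  input.dtype = paddle::PaddleDType::INT64;
  input.data.Resize(128 * sizeof(int64_t));
  // ... fill input.data ...

  std::vector<paddle::PaddleTensor> outputs;
  // Run() now goes SetFeed -> RunPreparedContext -> GetFetch internally.
  predictor->Run({input}, &outputs);
  return 0;
}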

paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc

Lines changed: 1 addition & 3 deletions

@@ -74,10 +74,8 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
     VLOG(5) << "to create variables";
     executor_->CreateVariables(*inference_program_,
                                sub_scope_ ? sub_scope_ : scope_.get(), 0);
-
     // Get the feed_target_names and fetch_target_names
-    feed_target_names_ = inference_program_->GetFeedTargetNames();
-    fetch_target_names_ = inference_program_->GetFetchTargetNames();
+    PrepareFeedFetch();
     return true;
   }

paddle/fluid/inference/tests/book/test_inference_nlp.cc

Lines changed: 25 additions & 2 deletions

@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/inference/tests/test_helper.h"
 #include "paddle/fluid/platform/cpu_helper.h"
 
+#include "paddle/fluid/framework/feed_fetch_method.h"
+
 DEFINE_string(model_path, "", "Directory of the inference model.");
 DEFINE_string(data_file, "", "File of input index data.");
 DEFINE_int32(repeat, 100, "Running the inference program repeat times");
@@ -124,14 +126,35 @@ void ThreadRunInfer(
   std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
   PADDLE_ENFORCE_EQ(feed_target_names.size(), 1UL);
 
+  // map the data of feed_targets to feed_holder
+  for (auto* op : inference_program->Block(0).AllOps()) {
+    if (op->Type() == "feed") {
+      std::string feed_target_name = op->Output("Out")[0];
+      int idx = boost::get<int>(op->GetAttr("col"));
+      paddle::framework::SetFeedVariable(scope, *feed_targets[feed_target_name],
+                                         "feed", idx);
+    }
+  }
+
   auto& inputs = jobs[tid];
   auto start_ms = GetCurrentMs();
   for (size_t i = 0; i < inputs.size(); ++i) {
     feed_targets[feed_target_names[0]] = inputs[i];
-    executor.RunPreparedContext(ctx.get(), &sub_scope, &feed_targets,
-                                &fetch_targets, false /*create_local_scope*/);
+    executor.RunPreparedContext(ctx.get(), &sub_scope,
+                                false /*create_local_scope*/);
   }
   auto stop_ms = GetCurrentMs();
+
+  // obtain the data of fetch_targets from fetch_holder
+  for (auto* op : inference_program->Block(0).AllOps()) {
+    if (op->Type() == "fetch") {
+      std::string fetch_target_name = op->Input("X")[0];
+      int idx = boost::get<int>(op->GetAttr("col"));
+      *fetch_targets[fetch_target_name] =
+          paddle::framework::GetFetchVariable(*scope, "fetch", idx);
+    }
+  }
+
   scope->DeleteScope(&sub_scope);
   LOG(INFO) << "Tid: " << tid << ", process " << inputs.size()
             << " samples, avg time per sample: "
