
Commit 4794d9c

use fast RunPreparedContext for inference

1 parent 515a756

4 files changed: 92 additions(+), 50 deletions(−)

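In short, the commit stops rebuilding string-keyed feed/fetch target maps on every Run() call. Feed tensors are now written straight into the scope's "feed" holder, the prepared context runs through the lighter RunPreparedContext overload (no map arguments), and results are read back from the "fetch" holder; PrepareFeedFetch() caches the feed/fetch op descriptors once at Init(). A condensed before/after sketch of the call site, assembled from the hunks below:

    // Before: feed/fetch maps rebuilt and passed to the executor per call.
    executor_->RunPreparedContext(ctx_.get(), scope, &feed_targets,
                                  &fetch_targets, false, false);

    // After: tensors are staged in the scope by index (the op's "col"
    // attribute), so the executor call carries no feed/fetch arguments.
    framework::SetFeedVariable(scope, input, "feed", idx);    // per input
    executor_->RunPreparedContext(ctx_.get(), scope, false, false);
    framework::LoDTensor &out =
        framework::GetFetchVariable(*scope, "fetch", idx);    // per output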

paddle/fluid/inference/api/api_impl.cc (57 additions, 40 deletions)

@@ -21,6 +21,7 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/inference/api/api_impl.h"
 #include "paddle/fluid/platform/profiler.h"
 
@@ -57,6 +58,27 @@ std::string num2str(T a) {
 }
 }  // namespace
 
+void NativePaddlePredictor::PrepareFeedFetch() {
+  for (auto *op : inference_program_->Block(0).AllOps()) {
+    if (op->Type() == "feed") {
+      int idx = boost::get<int>(op->GetAttr("col"));
+      if (feeds_.size() <= idx) {
+        feeds_.resize(idx + 1);
+      }
+      feeds_[idx] = op;
+      feed_names_[op->Output("Out")[0]] = idx;
+      LOG(ERROR) << "feed " << idx << " " << op->Output("Out")[0];
+    } else if (op->Type() == "fetch") {
+      int idx = boost::get<int>(op->GetAttr("col"));
+      if (fetchs_.size() <= idx) {
+        fetchs_.resize(idx + 1);
+      }
+      fetchs_[idx] = op;
+      LOG(ERROR) << "fetch " << idx << " " << op->Input("X")[0];
+    }
+  }
+}
+
 bool NativePaddlePredictor::Init(
     std::shared_ptr<framework::Scope> parent_scope) {
   VLOG(3) << "Predictor::init()";
@@ -108,8 +130,7 @@ bool NativePaddlePredictor::Init(
       sub_scope_ ? sub_scope_ : scope_.get(), 0);
 
   // Get the feed_target_names and fetch_target_names
-  feed_target_names_ = inference_program_->GetFeedTargetNames();
-  fetch_target_names_ = inference_program_->GetFetchTargetNames();
+  PrepareFeedFetch();
   return true;
 }
 
@@ -130,36 +151,21 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
   Timer timer;
   timer.tic();
   // set feed variable
-  std::map<std::string, const framework::LoDTensor *> feed_targets;
   std::vector<framework::LoDTensor> feeds;
-  if (!SetFeed(inputs, &feeds)) {
+  framework::Scope *scope = sub_scope_ != nullptr ? sub_scope_ : scope_.get();
+  if (!SetFeed(inputs, scope)) {
     LOG(ERROR) << "fail to set feed";
     return false;
   }
-  for (size_t i = 0; i < feed_target_names_.size(); ++i) {
-    if (config_.specify_input_name) {
-      feed_targets[inputs[i].name] = &feeds[i];
-    } else {
-      feed_targets[feed_target_names_[i]] = &feeds[i];
-    }
-  }
-  // get fetch variable
-  std::map<std::string, framework::LoDTensor *> fetch_targets;
-  std::vector<framework::LoDTensor> fetchs;
-  fetchs.resize(fetch_target_names_.size());
-  for (size_t i = 0; i < fetch_target_names_.size(); ++i) {
-    fetch_targets[fetch_target_names_[i]] = &fetchs[i];
-  }
   // Run the inference program
   // if share variables, we need not create variables
   VLOG(4) << "Run prepared context";
-  executor_->RunPreparedContext(
-      ctx_.get(), sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
-      &feed_targets, &fetch_targets,
-      false, /* don't create local scope each time */
-      false /* don't create variable each time */);
+  executor_->RunPreparedContext(ctx_.get(), scope,
+                                false, /* don't create local scope each time */
+                                false /* don't create variable each time */);
   VLOG(4) << "Finish prepared context";
-  if (!GetFetch(fetchs, output_data)) {
+  // get fetch variable
+  if (!GetFetch(output_data, scope)) {
     LOG(ERROR) << "fail to get fetches";
     return false;
   }
@@ -180,13 +186,13 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
 }
 
 bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
-                                    std::vector<framework::LoDTensor> *feeds) {
+                                    framework::Scope *scope) {
   VLOG(3) << "Predictor::set_feed";
-  if (inputs.size() != feed_target_names_.size()) {
+  if (inputs.size() != feeds_.size()) {
     LOG(ERROR) << "wrong feed input size.";
     return false;
   }
-  for (size_t i = 0; i < feed_target_names_.size(); ++i) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
     framework::LoDTensor input;
     framework::DDim ddim = framework::make_ddim(inputs[i].shape);
     void *input_ptr;
@@ -208,29 +214,40 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
       lod.emplace_back(level);
     }
     input.set_lod(lod);
-
-    feeds->push_back(input);
+    int idx = -1;
+    if (config_.specify_input_name) {
+      idx =
+          boost::get<int>(feeds_[feed_names_[inputs[i].name]]->GetAttr("col"));
+    } else {
+      idx = boost::get<int>(feeds_[i]->GetAttr("col"));
+    }
+    framework::SetFeedVariable(scope, input, "feed", idx);
   }
   return true;
 }
 
-bool NativePaddlePredictor::GetFetch(
-    const std::vector<framework::LoDTensor> &fetchs,
-    std::vector<PaddleTensor> *outputs) {
+bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
+                                     framework::Scope *scope) {
   VLOG(3) << "Predictor::get_fetch";
-  outputs->resize(fetchs.size());
-  for (size_t i = 0; i < fetchs.size(); ++i) {
+  outputs->resize(fetchs_.size());
+  for (size_t i = 0; i < fetchs_.size(); ++i) {
+    std::string fetch_target_name = fetchs_[i]->Input("X")[0];
+    int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
+    PADDLE_ENFORCE(idx == i);
+    framework::LoDTensor &output =
+        framework::GetFetchVariable(*scope, "fetch", idx);
     // TODO(panyx0718): Support fetch of other types.
-    if (fetchs[i].type() != typeid(float)) {
+    if (output.type() != typeid(float)) {
       LOG(ERROR) << "only support fetching float now.";
       return false;
     }
+
     std::vector<int> shape;
-    auto dims_i = fetchs[i].dims();
-    auto lod = fetchs[i].lod();
-    const float *output_ptr = fetchs[i].data<float>();
+    auto dims_i = output.dims();
+    auto lod = output.lod();
+    const float *output_ptr = output.data<float>();
     // const int64_t* output_ptr = fetchs[i].data<int64_t>();
-    auto num = fetchs[i].numel();
+    auto num = output.numel();
     std::vector<float> data;
     if (0 == lod.size()) {
       std::copy(output_ptr, output_ptr + num, std::back_inserter(data));
@@ -275,7 +292,7 @@ bool NativePaddlePredictor::GetFetch(
     }
     std::memcpy(buffer.data(), data.data(), buffer.length());
     // copy LoD
-    for (const auto &level : fetchs[i].lod()) {
+    for (const auto &level : output.lod()) {
      outputs->at(i).lod.emplace_back(level);
    }
    outputs->at(i).dtype = PaddleDType::FLOAT32;
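For intuition about the holder indexing that SetFeedVariable/GetFetchVariable rely on above: a holder behaves like a growable, index-addressed list of tensors stored under a named variable ("feed" or "fetch") in the scope, and each feed/fetch op's "col" attribute selects the slot. A toy model of that contract, illustrative only; ToyScope/ToySetFeed/ToyGetFetch are invented names for this sketch, not Paddle APIs:

    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Tensor { std::vector<float> data; };

    // Stand-in for a scope that owns named, index-addressed tensor lists.
    struct ToyScope {
      std::unordered_map<std::string, std::vector<Tensor>> holders;
    };

    // Analogous to framework::SetFeedVariable: grow the named holder as
    // needed, then write the tensor into slot `col`.
    void ToySetFeed(ToyScope* scope, const Tensor& t,
                    const std::string& holder, size_t col) {
      auto& list = scope->holders[holder];
      if (list.size() <= col) list.resize(col + 1);
      list[col] = t;
    }

    // Analogous to framework::GetFetchVariable: read slot `col`.
    const Tensor& ToyGetFetch(const ToyScope& scope,
                              const std::string& holder, size_t col) {
      return scope.holders.at(holder)[col];
    }

    int main() {
      ToyScope scope;
      ToySetFeed(&scope, Tensor{{1.f, 2.f}}, "feed", 0);
      return ToyGetFetch(scope, "feed", 0).data.size() == 2 ? 0 : 1;
    }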

paddle/fluid/inference/api/api_impl.h (9 additions, 5 deletions)

@@ -15,6 +15,7 @@
 #pragma once
 
 #include <glog/logging.h>
+#include <map>
 #include <memory>
 #include <string>
 #include <vector>
@@ -47,18 +48,21 @@ class NativePaddlePredictor : public PaddlePredictor {
 
  protected:
   bool SetFeed(const std::vector<PaddleTensor> &input_datas,
-               std::vector<framework::LoDTensor> *feeds);
-  bool GetFetch(const std::vector<framework::LoDTensor> &fetchs,
-                std::vector<PaddleTensor> *output_data);
+               framework::Scope *scope);
+  bool GetFetch(std::vector<PaddleTensor> *output_data,
+                framework::Scope *scope);
+
+  void PrepareFeedFetch();
 
   NativeConfig config_;
   platform::Place place_;
   std::unique_ptr<framework::Executor> executor_;
   std::shared_ptr<framework::Scope> scope_;
   std::unique_ptr<framework::ExecutorPrepareContext> ctx_;
   std::unique_ptr<framework::ProgramDesc> inference_program_;
-  std::vector<std::string> feed_target_names_;
-  std::vector<std::string> fetch_target_names_;
+  std::vector<framework::OpDesc *> feeds_;
+  std::map<std::string, size_t> feed_names_;
+  std::vector<framework::OpDesc *> fetchs_;
   // Do not use unique_ptr, use parent scope to delete
   framework::Scope *sub_scope_{nullptr};
 };
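Design note: replacing the feed_target_names_/fetch_target_names_ string vectors with cached framework::OpDesc pointers (feeds_, fetchs_) plus the feed_names_ index map moves all op scanning into the one-time PrepareFeedFetch() pass. Each subsequent Run() resolves an input to its feed slot either positionally or, when config_.specify_input_name is set, through a single feed_names_ lookup, rather than rebuilding name-to-tensor maps on every call.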

paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc (1 addition, 3 deletions)

@@ -74,10 +74,8 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
     VLOG(5) << "to create variables";
     executor_->CreateVariables(*inference_program_,
                                sub_scope_ ? sub_scope_ : scope_.get(), 0);
-
     // Get the feed_target_names and fetch_target_names
-    feed_target_names_ = inference_program_->GetFeedTargetNames();
-    fetch_target_names_ = inference_program_->GetFetchTargetNames();
+    PrepareFeedFetch();
     return true;
   }
 

paddle/fluid/inference/tests/book/test_inference_nlp.cc (25 additions, 2 deletions)

@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/inference/tests/test_helper.h"
 #include "paddle/fluid/platform/cpu_helper.h"
 
+#include "paddle/fluid/framework/feed_fetch_method.h"
+
 DEFINE_string(model_path, "", "Directory of the inference model.");
 DEFINE_string(data_file, "", "File of input index data.");
 DEFINE_int32(repeat, 100, "Running the inference program repeat times");
@@ -124,14 +126,35 @@ void ThreadRunInfer(
   std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
   PADDLE_ENFORCE_EQ(feed_target_names.size(), 1UL);
 
+  // map the data of feed_targets to feed_holder
+  for (auto* op : inference_program->Block(0).AllOps()) {
+    if (op->Type() == "feed") {
+      std::string feed_target_name = op->Output("Out")[0];
+      int idx = boost::get<int>(op->GetAttr("col"));
+      paddle::framework::SetFeedVariable(scope, *feed_targets[feed_target_name],
+                                         "feed", idx);
+    }
+  }
+
   auto& inputs = jobs[tid];
   auto start_ms = GetCurrentMs();
   for (size_t i = 0; i < inputs.size(); ++i) {
     feed_targets[feed_target_names[0]] = inputs[i];
-    executor.RunPreparedContext(ctx.get(), &sub_scope, &feed_targets,
-                                &fetch_targets, false /*create_local_scope*/);
+    executor.RunPreparedContext(ctx.get(), &sub_scope,
+                                false /*create_local_scope*/);
   }
   auto stop_ms = GetCurrentMs();
+
+  // obtain the data of fetch_targets from fetch_holder
+  for (auto* op : inference_program->Block(0).AllOps()) {
+    if (op->Type() == "fetch") {
+      std::string fetch_target_name = op->Input("X")[0];
+      int idx = boost::get<int>(op->GetAttr("col"));
+      *fetch_targets[fetch_target_name] =
+          paddle::framework::GetFetchVariable(*scope, "fetch", idx);
+    }
+  }
+
   scope->DeleteScope(&sub_scope);
   LOG(INFO) << "Tid: " << tid << ", process " << inputs.size()
             << " samples, avg time per sample: "
