@@ -21,6 +21,7 @@ limitations under the License. */
21
21
#include < utility>
22
22
#include < vector>
23
23
24
+ #include " paddle/fluid/framework/feed_fetch_method.h"
24
25
#include " paddle/fluid/inference/api/api_impl.h"
25
26
#include " paddle/fluid/platform/profiler.h"
26
27
@@ -57,6 +58,25 @@ std::string num2str(T a) {
57
58
}
58
59
} // namespace
59
60
61
+ void NativePaddlePredictor::PrepareFeedFetch () {
62
+ for (auto *op : inference_program_->Block (0 ).AllOps ()) {
63
+ if (op->Type () == " feed" ) {
64
+ int idx = boost::get<int >(op->GetAttr (" col" ));
65
+ if (feeds_.size () <= idx) {
66
+ feeds_.resize (idx + 1 );
67
+ }
68
+ feeds_[idx] = op;
69
+ feed_names_[op->Output (" Out" )[0 ]] = idx;
70
+ } else if (op->Type () == " fetch" ) {
71
+ int idx = boost::get<int >(op->GetAttr (" col" ));
72
+ if (fetchs_.size () <= idx) {
73
+ fetchs_.resize (idx + 1 );
74
+ }
75
+ fetchs_[idx] = op;
76
+ }
77
+ }
78
+ }
79
+
60
80
bool NativePaddlePredictor::Init (
61
81
std::shared_ptr<framework::Scope> parent_scope) {
62
82
VLOG (3 ) << " Predictor::init()" ;
@@ -108,8 +128,7 @@ bool NativePaddlePredictor::Init(
108
128
sub_scope_ ? sub_scope_ : scope_.get (), 0 );
109
129
110
130
// Get the feed_target_names and fetch_target_names
111
- feed_target_names_ = inference_program_->GetFeedTargetNames ();
112
- fetch_target_names_ = inference_program_->GetFetchTargetNames ();
131
+ PrepareFeedFetch ();
113
132
return true ;
114
133
}
115
134
@@ -130,36 +149,21 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
130
149
Timer timer;
131
150
timer.tic ();
132
151
// set feed variable
133
- std::map<std::string, const framework::LoDTensor *> feed_targets;
134
152
std::vector<framework::LoDTensor> feeds;
135
- if (!SetFeed (inputs, &feeds)) {
153
+ framework::Scope *scope = sub_scope_ != nullptr ? sub_scope_ : scope_.get ();
154
+ if (!SetFeed (inputs, scope)) {
136
155
LOG (ERROR) << " fail to set feed" ;
137
156
return false ;
138
157
}
139
- for (size_t i = 0 ; i < feed_target_names_.size (); ++i) {
140
- if (config_.specify_input_name ) {
141
- feed_targets[inputs[i].name ] = &feeds[i];
142
- } else {
143
- feed_targets[feed_target_names_[i]] = &feeds[i];
144
- }
145
- }
146
- // get fetch variable
147
- std::map<std::string, framework::LoDTensor *> fetch_targets;
148
- std::vector<framework::LoDTensor> fetchs;
149
- fetchs.resize (fetch_target_names_.size ());
150
- for (size_t i = 0 ; i < fetch_target_names_.size (); ++i) {
151
- fetch_targets[fetch_target_names_[i]] = &fetchs[i];
152
- }
153
158
// Run the inference program
154
159
// if share variables, we need not create variables
155
160
VLOG (4 ) << " Run prepared context" ;
156
- executor_->RunPreparedContext (
157
- ctx_.get (), sub_scope_ != nullptr ? sub_scope_ : scope_.get (),
158
- &feed_targets, &fetch_targets,
159
- false , /* don't create local scope each time*/
160
- false /* don't create variable eatch time */ );
161
+ executor_->RunPreparedContext (ctx_.get (), scope,
162
+ false , /* don't create local scope each time*/
163
+ false /* don't create variable eatch time */ );
161
164
VLOG (4 ) << " Finish prepared context" ;
162
- if (!GetFetch (fetchs, output_data)) {
165
+ // get fetch variable
166
+ if (!GetFetch (output_data, scope)) {
163
167
LOG (ERROR) << " fail to get fetches" ;
164
168
return false ;
165
169
}
@@ -180,13 +184,13 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
180
184
}
181
185
182
186
bool NativePaddlePredictor::SetFeed (const std::vector<PaddleTensor> &inputs,
183
- std::vector< framework::LoDTensor> *feeds ) {
187
+ framework::Scope *scope ) {
184
188
VLOG (3 ) << " Predictor::set_feed" ;
185
- if (inputs.size () != feed_target_names_ .size ()) {
189
+ if (inputs.size () != feeds_ .size ()) {
186
190
LOG (ERROR) << " wrong feed input size." ;
187
191
return false ;
188
192
}
189
- for (size_t i = 0 ; i < feed_target_names_ .size (); ++i) {
193
+ for (size_t i = 0 ; i < inputs .size (); ++i) {
190
194
framework::LoDTensor input;
191
195
framework::DDim ddim = framework::make_ddim (inputs[i].shape );
192
196
void *input_ptr;
@@ -208,29 +212,38 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
208
212
lod.emplace_back (level);
209
213
}
210
214
input.set_lod (lod);
211
-
212
- feeds->push_back (input);
215
+ int idx = -1 ;
216
+ if (config_.specify_input_name ) {
217
+ idx = feed_names_[inputs[i].name ];
218
+ } else {
219
+ idx = boost::get<int >(feeds_[i]->GetAttr (" col" ));
220
+ }
221
+ framework::SetFeedVariable (scope, input, " feed" , idx);
213
222
}
214
223
return true ;
215
224
}
216
225
217
- bool NativePaddlePredictor::GetFetch (
218
- const std::vector<framework::LoDTensor> &fetchs,
219
- std::vector<PaddleTensor> *outputs) {
226
+ bool NativePaddlePredictor::GetFetch (std::vector<PaddleTensor> *outputs,
227
+ framework::Scope *scope) {
220
228
VLOG (3 ) << " Predictor::get_fetch" ;
221
- outputs->resize (fetchs.size ());
222
- for (size_t i = 0 ; i < fetchs.size (); ++i) {
229
+ outputs->resize (fetchs_.size ());
230
+ for (size_t i = 0 ; i < fetchs_.size (); ++i) {
231
+ int idx = boost::get<int >(fetchs_[i]->GetAttr (" col" ));
232
+ PADDLE_ENFORCE (idx == i);
233
+ framework::LoDTensor &output =
234
+ framework::GetFetchVariable (*scope, " fetch" , idx);
223
235
// TODO(panyx0718): Support fetch of other types.
224
- if (fetchs[i] .type () != typeid (float )) {
236
+ if (output .type () != typeid (float )) {
225
237
LOG (ERROR) << " only support fetching float now." ;
226
238
return false ;
227
239
}
240
+
228
241
std::vector<int > shape;
229
- auto dims_i = fetchs[i] .dims ();
230
- auto lod = fetchs[i] .lod ();
231
- const float *output_ptr = fetchs[i] .data <float >();
242
+ auto dims_i = output .dims ();
243
+ auto lod = output .lod ();
244
+ const float *output_ptr = output .data <float >();
232
245
// const int64_t* output_ptr = fetchs[i].data<int64_t>();
233
- auto num = fetchs[i] .numel ();
246
+ auto num = output .numel ();
234
247
std::vector<float > data;
235
248
if (0 == lod.size ()) {
236
249
std::copy (output_ptr, output_ptr + num, std::back_inserter (data));
@@ -275,7 +288,7 @@ bool NativePaddlePredictor::GetFetch(
275
288
}
276
289
std::memcpy (buffer.data (), data.data (), buffer.length ());
277
290
// copy LoD
278
- for (const auto &level : fetchs[i] .lod ()) {
291
+ for (const auto &level : output .lod ()) {
279
292
outputs->at (i).lod .emplace_back (level);
280
293
}
281
294
outputs->at (i).dtype = PaddleDType::FLOAT32;
0 commit comments