@@ -21,6 +21,7 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/inference/api/api_impl.h"
 #include "paddle/fluid/platform/profiler.h"
 
@@ -57,6 +58,27 @@ std::string num2str(T a) {
 }
 }  // namespace
 
+void NativePaddlePredictor::PrepareFeedFetch() {
+  for (auto *op : inference_program_->Block(0).AllOps()) {
+    if (op->Type() == "feed") {
+      int idx = boost::get<int>(op->GetAttr("col"));
+      if (feeds_.size() <= idx) {
+        feeds_.resize(idx + 1);
+      }
+      feeds_[idx] = op;
+      feed_names_[op->Output("Out")[0]] = idx;
+      LOG(ERROR) << "feed " << idx << " " << op->Output("Out")[0];
+    } else if (op->Type() == "fetch") {
+      int idx = boost::get<int>(op->GetAttr("col"));
+      if (fetchs_.size() <= idx) {
+        fetchs_.resize(idx + 1);
+      }
+      fetchs_[idx] = op;
+      LOG(ERROR) << "fetch " << idx << " " << op->Input("X")[0];
+    }
+  }
+}
+
 bool NativePaddlePredictor::Init(
     std::shared_ptr<framework::Scope> parent_scope) {
   VLOG(3) << "Predictor::init()";
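Note: PrepareFeedFetch slots each feed/fetch op into a vector position given by its "col" attribute, so the ops can appear in any order in block 0. A standalone toy of that resize-and-slot pattern (plain C++, not Paddle code; the Op struct here is hypothetical):

    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
      struct Op { std::string type; int col; };
      // Ops may appear in any order; "col" names each op's slot.
      const std::vector<Op> ops = {{"feed", 1}, {"fetch", 0}, {"feed", 0}};
      std::vector<const Op *> feeds, fetchs;
      for (const auto &op : ops) {
        auto &slots = (op.type == "feed") ? feeds : fetchs;
        if (slots.size() <= static_cast<size_t>(op.col)) {
          slots.resize(op.col + 1);
        }
        slots[op.col] = &op;
      }
      assert(feeds.size() == 2 && feeds[1]->col == 1 && fetchs.size() == 1);
      return 0;
    }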
@@ -108,8 +130,7 @@ bool NativePaddlePredictor::Init(
       sub_scope_ ? sub_scope_ : scope_.get(), 0);
 
   // Get the feed_target_names and fetch_target_names
-  feed_target_names_ = inference_program_->GetFeedTargetNames();
-  fetch_target_names_ = inference_program_->GetFetchTargetNames();
+  PrepareFeedFetch();
   return true;
 }
 
@@ -130,36 +151,21 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
   Timer timer;
   timer.tic();
   // set feed variable
-  std::map<std::string, const framework::LoDTensor *> feed_targets;
   std::vector<framework::LoDTensor> feeds;
-  if (!SetFeed(inputs, &feeds)) {
+  framework::Scope *scope = sub_scope_ != nullptr ? sub_scope_ : scope_.get();
+  if (!SetFeed(inputs, scope)) {
     LOG(ERROR) << "fail to set feed";
     return false;
   }
-  for (size_t i = 0; i < feed_target_names_.size(); ++i) {
-    if (config_.specify_input_name) {
-      feed_targets[inputs[i].name] = &feeds[i];
-    } else {
-      feed_targets[feed_target_names_[i]] = &feeds[i];
-    }
-  }
-  // get fetch variable
-  std::map<std::string, framework::LoDTensor *> fetch_targets;
-  std::vector<framework::LoDTensor> fetchs;
-  fetchs.resize(fetch_target_names_.size());
-  for (size_t i = 0; i < fetch_target_names_.size(); ++i) {
-    fetch_targets[fetch_target_names_[i]] = &fetchs[i];
-  }
   // Run the inference program
   // if share variables, we need not create variables
   VLOG(4) << "Run prepared context";
-  executor_->RunPreparedContext(
-      ctx_.get(), sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
-      &feed_targets, &fetch_targets,
-      false, /* don't create local scope each time*/
-      false /* don't create variable eatch time */);
+  executor_->RunPreparedContext(ctx_.get(), scope,
+                                false, /* don't create local scope each time*/
+                                false /* don't create variable each time */);
   VLOG(4) << "Finish prepared context";
-  if (!GetFetch(fetchs, output_data)) {
+  // get fetch variable
+  if (!GetFetch(output_data, scope)) {
     LOG(ERROR) << "fail to get fetches";
     return false;
   }
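With this change, Run no longer hands feed_targets/fetch_targets maps to the executor: SetFeed deposits the inputs into the scope's "feed" variable before the run, the program's own feed/fetch ops move data through the scope, and GetFetch reads the "fetch" variable afterwards. A minimal stand-in for that scope-held handoff (ordinary C++ mimicking the SetFeedVariable/GetFetchVariable calls used below, not the real framework::Scope API):

    #include <cassert>
    #include <map>
    #include <string>
    #include <vector>

    using Tensor = std::vector<float>;
    // A scope mapping variable names ("feed", "fetch") to tensor lists
    // indexed by the ops' "col" attribute.
    using Scope = std::map<std::string, std::vector<Tensor>>;

    void SetFeedVariable(Scope *scope, const Tensor &t,
                         const std::string &var, int idx) {
      auto &list = (*scope)[var];
      if (list.size() <= static_cast<size_t>(idx)) list.resize(idx + 1);
      list[idx] = t;
    }

    const Tensor &GetFetchVariable(const Scope &scope,
                                   const std::string &var, int idx) {
      return scope.at(var)[idx];
    }

    int main() {
      Scope scope;
      SetFeedVariable(&scope, {1.f, 2.f}, "feed", 0);
      // A real run would let feed/fetch ops copy data through the graph;
      // here we just forward the input to stand in for the executor.
      SetFeedVariable(&scope, scope["feed"][0], "fetch", 0);
      assert(GetFetchVariable(scope, "fetch", 0).size() == 2);
      return 0;
    }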
@@ -180,13 +186,13 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
 }
 
 bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
-                                    std::vector<framework::LoDTensor> *feeds) {
+                                    framework::Scope *scope) {
   VLOG(3) << "Predictor::set_feed";
-  if (inputs.size() != feed_target_names_.size()) {
+  if (inputs.size() != feeds_.size()) {
     LOG(ERROR) << "wrong feed input size.";
     return false;
   }
-  for (size_t i = 0; i < feed_target_names_.size(); ++i) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
     framework::LoDTensor input;
     framework::DDim ddim = framework::make_ddim(inputs[i].shape);
     void *input_ptr;
@@ -208,29 +214,40 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
       lod.emplace_back(level);
     }
     input.set_lod(lod);
-
-    feeds->push_back(input);
+    int idx = -1;
+    if (config_.specify_input_name) {
+      idx =
+          boost::get<int>(feeds_[feed_names_[inputs[i].name]]->GetAttr("col"));
+    } else {
+      idx = boost::get<int>(feeds_[i]->GetAttr("col"));
+    }
+    framework::SetFeedVariable(scope, input, "feed", idx);
   }
   return true;
 }
 
-bool NativePaddlePredictor::GetFetch(
-    const std::vector<framework::LoDTensor> &fetchs,
-    std::vector<PaddleTensor> *outputs) {
+bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
+                                     framework::Scope *scope) {
   VLOG(3) << "Predictor::get_fetch";
-  outputs->resize(fetchs.size());
-  for (size_t i = 0; i < fetchs.size(); ++i) {
+  outputs->resize(fetchs_.size());
+  for (size_t i = 0; i < fetchs_.size(); ++i) {
+    std::string fetch_target_name = fetchs_[i]->Input("X")[0];
+    int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
+    PADDLE_ENFORCE(idx == i);
+    framework::LoDTensor &output =
+        framework::GetFetchVariable(*scope, "fetch", idx);
     // TODO(panyx0718): Support fetch of other types.
-    if (fetchs[i].type() != typeid(float)) {
+    if (output.type() != typeid(float)) {
       LOG(ERROR) << "only support fetching float now.";
       return false;
     }
+
     std::vector<int> shape;
-    auto dims_i = fetchs[i].dims();
-    auto lod = fetchs[i].lod();
-    const float *output_ptr = fetchs[i].data<float>();
+    auto dims_i = output.dims();
+    auto lod = output.lod();
+    const float *output_ptr = output.data<float>();
     // const int64_t* output_ptr = fetchs[i].data<int64_t>();
-    auto num = fetchs[i].numel();
+    auto num = output.numel();
     std::vector<float> data;
     if (0 == lod.size()) {
       std::copy(output_ptr, output_ptr + num, std::back_inserter(data));
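SetFeed above resolves each input's slot either by name (config_.specify_input_name routes through feed_names_, which PrepareFeedFetch filled) or by position. A tiny illustration of the difference (hypothetical tensor names, not Paddle code):

    #include <cassert>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
      // feed_names_: tensor name -> index of the cached feed op.
      std::map<std::string, int> feed_names = {{"image", 0}, {"label", 1}};
      // Inputs may arrive in a different order than the feed ops expect.
      std::vector<std::pair<std::string, float>> inputs = {{"label", 7.f},
                                                           {"image", 3.f}};
      const bool specify_input_name = true;
      // i == 0: by name, "label" resolves to feed slot 1 ...
      int idx = specify_input_name ? feed_names[inputs[0].first] : 0;
      assert(idx == 1);
      // ... whereas positional matching would wrongly feed slot 0.
      return 0;
    }

On the fetch side, PADDLE_ENFORCE(idx == i) assumes the fetch ops' "col" attributes are dense and in order, so outputs can be returned positionally.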
@@ -275,7 +292,7 @@ bool NativePaddlePredictor::GetFetch(
     }
     std::memcpy(buffer.data(), data.data(), buffer.length());
     // copy LoD
-    for (const auto &level : fetchs[i].lod()) {
+    for (const auto &level : output.lod()) {
       outputs->at(i).lod.emplace_back(level);
     }
     outputs->at(i).dtype = PaddleDType::FLOAT32;