@@ -33,19 +33,31 @@ limitations under the License. */
 
 using namespace paddle_infer;
 
+DEFINE_int32(batch_size, 1, "Batch size to do inference.");
+DEFINE_int32(beam_size, 5, "Beam size to do inference.");
+DEFINE_int32(gpu_id, 0, "The gpu id to do inference.");
+DEFINE_string(model_dir,
+              "./infer_model/",
+              "The directory to the inference model.");
+DEFINE_string(vocab_dir,
+              "./vocab_all.bpe.33708",
+              "The directory to the vocabulary file.");
+DEFINE_string(data_dir,
+              "./newstest2014.tok.bpe.33708.en",
+              "The directory to the input data.");
 
 std::string model_dir = "";
-std::string dict_dir = "";
-std::string datapath = "";
+std::string vocab_dir = "";
+std::string data_dir = "";
 
-const int eos_idx = 1;
-const int pad_idx = 0;
-const int beam_size = 5;
-const int max_length = 256;
-const int n_best = 1;
+const int EOS_IDX = 1;
+const int PAD_IDX = 0;
+const int MAX_LENGTH = 256;
+const int N_BEST = 1;
 
 int batch_size = 1;
 int gpu_id = 0;
+int beam_size = 5;
 
 namespace paddle {
 namespace inference {
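(Aside, not part of the diff: the hunk above swaps hard-coded globals for gflags definitions. Each DEFINE_int32/DEFINE_string macro registers a command-line flag and generates a matching FLAGS_<name> global, which main() fills in via gflags::ParseCommandLineFlags, as the final hunk shows. A minimal self-contained sketch of the same pattern, with illustrative flag names borrowed from the diff:

#include <gflags/gflags.h>
#include <iostream>

DEFINE_int32(batch_size, 1, "Batch size to do inference.");
DEFINE_string(model_dir, "./infer_model/", "The directory to the inference model.");

int main(int argc, char** argv) {
  // Third argument true: strip recognized flags from argc/argv after parsing.
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  std::cout << "batch_size=" << FLAGS_batch_size
            << ", model_dir=" << FLAGS_model_dir << std::endl;
  return 0;
}

Building such a sketch requires linking against gflags, e.g. -lgflags.)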
@@ -69,18 +81,18 @@ bool get_result_tensor(const std::unique_ptr<paddle_infer::Tensor>& seq_ids,
   seq_ids_out.resize(out_num);
   seq_ids->CopyToCpu(seq_ids_out.data());
 
-  dataresultvec.resize(batch_size * n_best);
+  dataresultvec.resize(batch_size * N_BEST);
   auto max_output_length = output_shape[0];
 
   for (int bsz = 0; bsz < output_shape[1]; ++bsz) {
-    for (int k = 0; k < n_best; ++k) {
-      dataresultvec[bsz * n_best + k].result_q = "";
+    for (int k = 0; k < N_BEST; ++k) {
+      dataresultvec[bsz * N_BEST + k].result_q = "";
       for (int len = 0; len < max_output_length; ++len) {
         if (seq_ids_out[len * batch_size * beam_size + bsz * beam_size + k] ==
-            eos_idx)
+            EOS_IDX)
           break;
-        dataresultvec[bsz * n_best + k].result_q =
-            dataresultvec[bsz * n_best + k].result_q +
+        dataresultvec[bsz * N_BEST + k].result_q =
+            dataresultvec[bsz * N_BEST + k].result_q +
             num2word_dict[seq_ids_out[len * batch_size * beam_size +
                                       bsz * beam_size + k]] +
             " ";
@@ -110,7 +122,7 @@ class DataReader {
       split(line, ' ', &word_data);
       std::string query_str = "";
       for (int j = 0; j < word_data.size(); ++j) {
-        if (j >= max_length) {
+        if (j >= MAX_LENGTH) {
           break;
         }
         query_str += word_data[j];
@@ -121,9 +133,9 @@ class DataReader {
         }
       }
       source_query_vec.push_back(query_str);
-      data_input.src_data.push_back(eos_idx);
+      data_input.src_data.push_back(EOS_IDX);
       max_len = std::max(max_len, static_cast<int>(data_input.src_data.size()));
-      max_len = std::min(max_len, max_length);
+      max_len = std::min(max_len, MAX_LENGTH);
       data_input_vec.push_back(data_input);
     }
     if (data_input_vec.empty()) {
@@ -134,7 +146,7 @@ class DataReader {
   }
 
   bool GetWordDict() {
-    std::ifstream fin(dict_dir);
+    std::ifstream fin(vocab_dir);
     std::string line;
     int k = 0;
     while (std::getline(fin, line)) {
@@ -165,7 +177,7 @@ class DataReader {
       if (k < data_input_vec[i].src_data.size()) {
         src_word_vec[i * max_len + k] = data_input_vec[i].src_data[k];
       } else {
-        src_word_vec[i * max_len + k] = pad_idx;
+        src_word_vec[i * max_len + k] = PAD_IDX;
       }
     }
   }
@@ -204,7 +216,7 @@ void Main(int batch_size, int gpu_id) {
   config.SwitchUseFeedFetchOps(false);
   config.SwitchSpecifyInputNames(true);
   auto predictor = CreatePredictor(config);
-  DataReader reader(datapath);
+  DataReader reader(data_dir);
   reader.GetWordDict();
 
   double whole_time = 0;
@@ -242,12 +254,15 @@ void Main(int batch_size, int gpu_id) {
 }  // namespace paddle
 
 int main(int argc, char** argv) {
-  batch_size = std::stoi(std::string(argv[1]));
-  gpu_id = std::stoi(std::string(argv[2]));
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
 
-  model_dir = std::string(argv[3]);
-  dict_dir = std::string(argv[4]);
-  datapath = std::string(argv[5]);
+  batch_size = FLAGS_batch_size;
+  gpu_id = FLAGS_gpu_id;
+  beam_size = FLAGS_beam_size;
+
+  model_dir = FLAGS_model_dir;
+  vocab_dir = FLAGS_vocab_dir;
+  data_dir = FLAGS_data_dir;
 
   paddle::inference::Main(batch_size, gpu_id);
 
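(Usage note: with the final hunk applied, the binary no longer reads five positional arguments (batch size, gpu id, model dir, vocab path, data path as argv[1]..argv[5]). It is driven by named flags instead, e.g. --batch_size=8 --gpu_id=0 --model_dir=./infer_model/ --vocab_dir=./vocab_all.bpe.33708 --data_dir=./newstest2014.tok.bpe.33708.en, and any flag left unspecified falls back to the default declared in the DEFINE_* block. beam_size, previously a compile-time constant, is now runtime-configurable as well.)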