
Commit 480a33d

[Faster Transformer] Refine transformer cpp inference demo (#575)

FrostML and ZeyuChen authored

* refine faster transformer transformer demo

Co-authored-by: Zeyu Chen <[email protected]>

1 parent 57ce415 · commit 480a33d

File tree: 4 files changed, +45 −35 lines

examples/machine_translation/transformer/faster_transformer/README.md

Lines changed: 3 additions & 8 deletions
````diff
@@ -237,19 +237,14 @@ cd ../
 
 After compilation, an executable named `transformer_e2e` appears under `build/bin/`. Execution is driven by setting the corresponding arguments.
 
-``` sh
-cd bin/
-./transformer_e2e <batch_size> <gpu_id> <model_directory> <dict_directory> <input_data>
-```
-
 ### Export model files usable by the inference library based on the Faster Transformer custom op
 
 We provide a ready-to-use checkpoint of a base model already trained with the dynamic graph; the current checkpoint was trained on the WMT English-German translation task. It can be downloaded from [tranformer-base-wmt_ende_bpe](https://paddlenlp.bj.bcebos.com/models/transformers/transformer/tranformer-base-wmt_ende_bpe.tar.gz).
 
 To use the C++ inference library, we first need to export the dynamic-graph checkpoint into the model and parameter files the inference library can consume. Running `export_model.py` accomplishes this.
 
 ``` sh
-python export_model.py --config ../configs/transformer.base.yaml --decoding_lib ../../../../paddlenlp/ops/src/build/lib/libdecoding_op.so --decoding_strategy beam_search --beam_size 5
+python export_model.py --config ../configs/transformer.base.yaml --decoding_lib ../../../../paddlenlp/ops/build/lib/libdecoding_op.so --decoding_strategy beam_search --beam_size 5
 ```
 
 Note: the `libdecoding_op.so` shared library here is the lib built by following the earlier section **`Using the custom op from the Python dynamic graph`**; the current **`Using the custom op from the C++ inference library`** step does not include this compiled shared library. Therefore, if you additionally need to export the model before using the inference library, you have to build twice:
@@ -269,7 +264,7 @@ python export_model.py --config ../configs/transformer.base.yaml --decoding_lib
 
 ``` sh
 cd bin/
-./transformer_e2e <batch_size> <gpu_id> <model_directory> <dict_directory> <input_data>
+./transformer_e2e -batch_size <batch_size> -beam_size <beam_size> -gpu_id <gpu_id> -model_dir <model_directory> -vocab_dir <dict_directory> -data_dir <input_data>
 ```
 
 Here `<model_directory>` is the exported Paddle Inference model mentioned above.
@@ -279,7 +274,7 @@ cd bin/
 ``` sh
 cd bin/
 ../third-party/build/bin/decoding_gemm 8 5 8 64 38512 256 512 0
-./transformer_e2e 8 0 ./infer_model/ DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/vocab_all.bpe.33708 DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en
+./transformer_e2e -batch_size 8 -beam_size 5 -gpu_id 0 -model_dir ./infer_model/ -vocab_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/vocab_all.bpe.33708 -data_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en
 ```
 
 Where:
````
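The named flags in the updated commands are handled by gflags in the refactored demo (see the `transformer_e2e.cc` diff below). As a minimal, self-contained sketch of that mechanism — the flag names and defaults are taken from the diff, while the `main` body here is illustrative only and simply echoes the parsed values:

``` cpp
// Minimal gflags sketch mirroring the demo's new flags (illustrative only).
// Build by linking against gflags, e.g.: g++ flags_sketch.cc -lgflags
#include <iostream>
#include <string>

#include <gflags/gflags.h>

// Same form as in transformer_e2e.cc: flag name, default value, help text.
DEFINE_int32(batch_size, 1, "Batch size to do inference. ");
DEFINE_int32(beam_size, 5, "Beam size to do inference. ");
DEFINE_int32(gpu_id, 0, "The gpu id to do inference. ");
DEFINE_string(model_dir, "./infer_model/", "The directory to the inference model. ");

int main(int argc, char** argv) {
  // gflags accepts both -flag and --flag spellings, which is why the README
  // examples can write `-batch_size 8`.
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);

  // Parsed values are exposed as FLAGS_<name> globals.
  std::cout << "batch_size=" << FLAGS_batch_size
            << " beam_size=" << FLAGS_beam_size
            << " gpu_id=" << FLAGS_gpu_id
            << " model_dir=" << FLAGS_model_dir << std::endl;
  return 0;
}
```

Because every flag carries a declared default, any subset may be omitted, so `./transformer_e2e -batch_size 8` alone is a valid invocation; the old positional interface required all five arguments in a fixed order.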

examples/machine_translation/transformer/faster_transformer/export_model.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -29,7 +29,7 @@ def parse_args():
         help="Path of the config file. ")
     parser.add_argument(
         "--decoding_lib",
-        default="../../../../paddlenlp/ops/src/build/lib/libdecoding_op.so",
+        default="../../../../paddlenlp/ops/build/lib/libdecoding_op.so",
         type=str,
         help="Path of libdecoding_op.so. ")
     parser.add_argument(
```

paddlenlp/ops/README.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -236,15 +236,15 @@ cd ../
 
 ``` sh
 cd bin/
-./transformer_e2e <batch_size> <gpu_id> <model_directory> <dict_directory> <input_data>
+./transformer_e2e -batch_size <batch_size> -beam_size <beam_size> -gpu_id <gpu_id> -model_dir <model_directory> -vocab_dir <dict_directory> -data_dir <input_data>
 ```
 
 For example:
 
 ``` sh
 cd bin/
 ../third-party/build/bin/decoding_gemm 8 5 8 64 38512 256 512 0
-./transformer_e2e 8 0 ./infer_model/ DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/vocab_all.bpe.33708 DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en
+./transformer_e2e -batch_size 8 -beam_size 5 -gpu_id 0 -model_dir ./infer_model/ -vocab_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/vocab_all.bpe.33708 -data_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en
 ```
 
 Where:
````

paddlenlp/ops/faster_transformer/src/demo/transformer_e2e.cc

Lines changed: 39 additions & 24 deletions
```diff
@@ -33,19 +33,31 @@ limitations under the License. */
 
 using namespace paddle_infer;
 
+DEFINE_int32(batch_size, 1, "Batch size to do inference. ");
+DEFINE_int32(beam_size, 5, "Beam size to do inference. ");
+DEFINE_int32(gpu_id, 0, "The gpu id to do inference. ");
+DEFINE_string(model_dir,
+              "./infer_model/",
+              "The directory to the inference model. ");
+DEFINE_string(vocab_dir,
+              "./vocab_all.bpe.33708",
+              "The directory to the vocabulary file. ");
+DEFINE_string(data_dir,
+              "./newstest2014.tok.bpe.33708.en",
+              "The directory to the input data. ");
 
 std::string model_dir = "";
-std::string dict_dir = "";
-std::string datapath = "";
+std::string vocab_dir = "";
+std::string data_dir = "";
 
-const int eos_idx = 1;
-const int pad_idx = 0;
-const int beam_size = 5;
-const int max_length = 256;
-const int n_best = 1;
+const int EOS_IDX = 1;
+const int PAD_IDX = 0;
+const int MAX_LENGTH = 256;
+const int N_BEST = 1;
 
 int batch_size = 1;
 int gpu_id = 0;
+int beam_size = 5;
 
 namespace paddle {
 namespace inference {
@@ -69,18 +81,18 @@ bool get_result_tensor(const std::unique_ptr<paddle_infer::Tensor>& seq_ids,
   seq_ids_out.resize(out_num);
   seq_ids->CopyToCpu(seq_ids_out.data());
 
-  dataresultvec.resize(batch_size * n_best);
+  dataresultvec.resize(batch_size * N_BEST);
   auto max_output_length = output_shape[0];
 
   for (int bsz = 0; bsz < output_shape[1]; ++bsz) {
-    for (int k = 0; k < n_best; ++k) {
-      dataresultvec[bsz * n_best + k].result_q = "";
+    for (int k = 0; k < N_BEST; ++k) {
+      dataresultvec[bsz * N_BEST + k].result_q = "";
       for (int len = 0; len < max_output_length; ++len) {
         if (seq_ids_out[len * batch_size * beam_size + bsz * beam_size + k] ==
-            eos_idx)
+            EOS_IDX)
           break;
-        dataresultvec[bsz * n_best + k].result_q =
-            dataresultvec[bsz * n_best + k].result_q +
+        dataresultvec[bsz * N_BEST + k].result_q =
+            dataresultvec[bsz * N_BEST + k].result_q +
             num2word_dict[seq_ids_out[len * batch_size * beam_size +
                                       bsz * beam_size + k]] +
             " ";
@@ -110,7 +122,7 @@ class DataReader {
       split(line, ' ', &word_data);
       std::string query_str = "";
       for (int j = 0; j < word_data.size(); ++j) {
-        if (j >= max_length) {
+        if (j >= MAX_LENGTH) {
           break;
         }
         query_str += word_data[j];
@@ -121,9 +133,9 @@ class DataReader {
         }
       }
       source_query_vec.push_back(query_str);
-      data_input.src_data.push_back(eos_idx);
+      data_input.src_data.push_back(EOS_IDX);
       max_len = std::max(max_len, static_cast<int>(data_input.src_data.size()));
-      max_len = std::min(max_len, max_length);
+      max_len = std::min(max_len, MAX_LENGTH);
       data_input_vec.push_back(data_input);
     }
     if (data_input_vec.empty()) {
@@ -134,7 +146,7 @@ class DataReader {
   }
 
   bool GetWordDict() {
-    std::ifstream fin(dict_dir);
+    std::ifstream fin(vocab_dir);
     std::string line;
     int k = 0;
     while (std::getline(fin, line)) {
@@ -165,7 +177,7 @@ class DataReader {
       if (k < data_input_vec[i].src_data.size()) {
         src_word_vec[i * max_len + k] = data_input_vec[i].src_data[k];
       } else {
-        src_word_vec[i * max_len + k] = pad_idx;
+        src_word_vec[i * max_len + k] = PAD_IDX;
       }
     }
   }
@@ -204,7 +216,7 @@ void Main(int batch_size, int gpu_id) {
   config.SwitchUseFeedFetchOps(false);
   config.SwitchSpecifyInputNames(true);
   auto predictor = CreatePredictor(config);
-  DataReader reader(datapath);
+  DataReader reader(data_dir);
   reader.GetWordDict();
 
   double whole_time = 0;
@@ -242,12 +254,15 @@ void Main(int batch_size, int gpu_id) {
 }  // namespace paddle
 
 int main(int argc, char** argv) {
-  batch_size = std::stoi(std::string(argv[1]));
-  gpu_id = std::stoi(std::string(argv[2]));
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
 
-  model_dir = std::string(argv[3]);
-  dict_dir = std::string(argv[4]);
-  datapath = std::string(argv[5]);
+  batch_size = FLAGS_batch_size;
+  gpu_id = FLAGS_gpu_id;
+  beam_size = FLAGS_beam_size;
+
+  model_dir = FLAGS_model_dir;
+  vocab_dir = FLAGS_vocab_dir;
+  data_dir = FLAGS_data_dir;
 
   paddle::inference::Main(batch_size, gpu_id);
```