Commit d0c65bf

Merge pull request #13100 from luotao1/ner_ut2
add unit-test for chinese_ner
2 parents 7e3a884 + b3cd2ae commit d0c65bf

6 files changed: 271 additions and 87 deletions
paddle/fluid/inference/analysis/CMakeLists.txt

Lines changed: 23 additions & 12 deletions
@@ -40,23 +40,20 @@ function (inference_analysis_test TARGET)
   endif(WITH_TESTING)
 endfunction(inference_analysis_test)
 
-set(DITU_RNN_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fmodel.tar.gz")
-set(DITU_RNN_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fdata.txt.tar.gz")
-set(DITU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/ditu_rnn" CACHE PATH "Ditu RNN model and data root." FORCE)
-set(DITU_RNN_MODEL ${DITU_INSTALL_DIR}/model)
-set(DITU_RNN_DATA ${DITU_INSTALL_DIR}/data.txt)
-
-function (inference_download_and_uncompress target url gz_filename)
+function (inference_download_and_uncompress install_dir url gz_filename)
   message(STATUS "Download inference test stuff ${gz_filename} from ${url}")
-  execute_process(COMMAND bash -c "mkdir -p ${DITU_INSTALL_DIR}")
-  execute_process(COMMAND bash -c "cd ${DITU_INSTALL_DIR} && wget -q ${url}")
-  execute_process(COMMAND bash -c "cd ${DITU_INSTALL_DIR} && tar xzf ${gz_filename}")
+  execute_process(COMMAND bash -c "mkdir -p ${install_dir}")
+  execute_process(COMMAND bash -c "cd ${install_dir} && wget -q ${url}")
+  execute_process(COMMAND bash -c "cd ${install_dir} && tar xzf ${gz_filename}")
   message(STATUS "finish downloading ${gz_filename}")
 endfunction(inference_download_and_uncompress)
 
+set(DITU_RNN_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fmodel.tar.gz")
+set(DITU_RNN_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fdata.txt.tar.gz")
+set(DITU_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/ditu_rnn" CACHE PATH "Ditu RNN model and data root." FORCE)
 if (NOT EXISTS ${DITU_INSTALL_DIR})
-  inference_download_and_uncompress(ditu_rnn_model ${DITU_RNN_MODEL_URL} "ditu_rnn_fluid%2Fmodel.tar.gz")
-  inference_download_and_uncompress(ditu_rnn_data ${DITU_RNN_DATA_URL} "ditu_rnn_fluid%2Fdata.txt.tar.gz")
+  inference_download_and_uncompress(${DITU_INSTALL_DIR} ${DITU_RNN_MODEL_URL} "ditu_rnn_fluid%2Fmodel.tar.gz")
+  inference_download_and_uncompress(${DITU_INSTALL_DIR} ${DITU_RNN_DATA_URL} "ditu_rnn_fluid%2Fdata.txt.tar.gz")
 endif()
 
 inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
@@ -87,3 +84,17 @@ inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_
 inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
 inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
 inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)
+
+set(CHINESE_NER_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/chinese_ner_model.tar.gz")
+set(CHINESE_NER_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/chinese_ner-data.txt.tar.gz")
+set(CHINESE_NER_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/chinese_ner" CACHE PATH "Chinese ner model and data root." FORCE)
+if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR})
+  inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_MODEL_URL} "chinese_ner_model.tar.gz")
+  inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_DATA_URL} "chinese_ner-data.txt.tar.gz")
+endif()
+
+inference_analysis_test(test_chinese_ner SRCS chinese_ner_tester.cc
+  EXTRA_DEPS paddle_inference_api paddle_fluid_api
+  ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model
+    --infer_model=${CHINESE_NER_INSTALL_DIR}/model
+    --infer_data=${CHINESE_NER_INSTALL_DIR}/data.txt)

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 6 additions & 6 deletions
@@ -201,13 +201,13 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
   minute_tensor.lod.assign({one_batch.lod3});
   // clang-format on
   // assign data
-  TensorAssignData(&lod_attention_tensor,
-                   std::vector<std::vector<float>>({{0, 0}}));
+  TensorAssignData<float>(&lod_attention_tensor,
+                          std::vector<std::vector<float>>({{0, 0}}));
   std::vector<float> tmp_zeros(batch_size * 15, 0.);
-  TensorAssignData(&init_zero_tensor, {tmp_zeros});
-  TensorAssignData(&lod_tensor_tensor, one_batch.rnn_link_data);
-  TensorAssignData(&week_tensor, one_batch.rnn_week_datas);
-  TensorAssignData(&minute_tensor, one_batch.rnn_minute_datas);
+  TensorAssignData<float>(&init_zero_tensor, {tmp_zeros});
+  TensorAssignData<float>(&lod_tensor_tensor, one_batch.rnn_link_data);
+  TensorAssignData<float>(&week_tensor, one_batch.rnn_week_datas);
+  TensorAssignData<float>(&minute_tensor, one_batch.rnn_minute_datas);
   // Set inputs.
   auto init_zero_tensor1 = init_zero_tensor;
   init_zero_tensor1.name = "hidden_init";
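
A note on the hunk above: the TensorAssignData helper (declared in paddle/fluid/inference/api/helper.h) now has to serve both float and int64_t inputs, which is why every call site spells out the element type explicitly. The stand-in below is only a sketch of that pattern under assumed names; FakeTensor and AssignData are invented for illustration and are not the real Paddle helper or its signature.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for a tensor with a flat byte buffer.
struct FakeTensor {
  std::vector<char> data;
};

// A templated assign helper: the element type decides how many bytes each
// value occupies, so callers state it explicitly when it cannot be deduced
// from a braced initializer.
template <typename T>
void AssignData(FakeTensor *tensor, const std::vector<std::vector<T>> &rows) {
  size_t total = 0;
  for (const auto &row : rows) total += row.size();
  tensor->data.resize(total * sizeof(T));
  size_t offset = 0;
  for (const auto &row : rows) {
    std::memcpy(tensor->data.data() + offset, row.data(),
                row.size() * sizeof(T));
    offset += row.size() * sizeof(T);
  }
}

int main() {
  FakeTensor t;
  AssignData<float>(&t, {{0.f, 0.f}});      // float inputs, 4 bytes per value
  AssignData<int64_t>(&t, {{30, 45, 41}});  // int64 inputs, 8 bytes per value
  return 0;
}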

paddle/fluid/inference/analysis/chinese_ner_tester.cc (new file)

Lines changed: 154 additions & 0 deletions

@@ -0,0 +1,154 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <google/protobuf/text_format.h>
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/inference/analysis/analyzer.h"
+#include "paddle/fluid/inference/analysis/ut_helper.h"
+#include "paddle/fluid/inference/api/helper.h"
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/platform/profiler.h"
+
+DEFINE_string(infer_model, "", "model path");
+DEFINE_string(infer_data, "", "data path");
+DEFINE_int32(batch_size, 10, "batch size.");
+DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
+
+namespace paddle {
+namespace inference {
+
+struct DataRecord {
+  std::vector<std::vector<int64_t>> word_data_all, mention_data_all;
+  std::vector<std::vector<int64_t>> rnn_word_datas, rnn_mention_datas;
+  std::vector<size_t> lod;  // two inputs have the same lod info.
+  size_t batch_iter{0};
+  size_t batch_size{1};
+  DataRecord() = default;
+  explicit DataRecord(const std::string &path, int batch_size = 1)
+      : batch_size(batch_size) {
+    Load(path);
+  }
+  DataRecord NextBatch() {
+    DataRecord data;
+    size_t batch_end = batch_iter + batch_size;
+    // NOTE skip the final batch, if no enough data is provided.
+    if (batch_end <= word_data_all.size()) {
+      data.word_data_all.assign(word_data_all.begin() + batch_iter,
+                                word_data_all.begin() + batch_end);
+      data.mention_data_all.assign(mention_data_all.begin() + batch_iter,
+                                   mention_data_all.begin() + batch_end);
+      // Prepare LoDs
+      data.lod.push_back(0);
+      CHECK(!data.word_data_all.empty());
+      CHECK(!data.mention_data_all.empty());
+      CHECK_EQ(data.word_data_all.size(), data.mention_data_all.size());
+      for (size_t j = 0; j < data.word_data_all.size(); j++) {
+        data.rnn_word_datas.push_back(data.word_data_all[j]);
+        data.rnn_mention_datas.push_back(data.mention_data_all[j]);
+        // calculate lod
+        data.lod.push_back(data.lod.back() + data.word_data_all[j].size());
+      }
+    }
+    batch_iter += batch_size;
+    return data;
+  }
+  void Load(const std::string &path) {
+    std::ifstream file(path);
+    std::string line;
+    int num_lines = 0;
+    while (std::getline(file, line)) {
+      num_lines++;
+      std::vector<std::string> data;
+      split(line, ';', &data);
+      // load word data
+      std::vector<int64_t> word_data;
+      split_to_int64(data[1], ' ', &word_data);
+      // load mention data
+      std::vector<int64_t> mention_data;
+      split_to_int64(data[3], ' ', &mention_data);
+      word_data_all.push_back(std::move(word_data));
+      mention_data_all.push_back(std::move(mention_data));
+    }
+  }
+};
+
+void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
+                   int batch_size) {
+  PaddleTensor lod_word_tensor, lod_mention_tensor;
+  lod_word_tensor.name = "word";
+  lod_mention_tensor.name = "mention";
+  auto one_batch = data->NextBatch();
+  int size = one_batch.lod[one_batch.lod.size() - 1];  // token batch size
+  lod_word_tensor.shape.assign({size, 1});
+  lod_word_tensor.lod.assign({one_batch.lod});
+  lod_mention_tensor.shape.assign({size, 1});
+  lod_mention_tensor.lod.assign({one_batch.lod});
+  // assign data
+  TensorAssignData<int64_t>(&lod_word_tensor, one_batch.rnn_word_datas);
+  TensorAssignData<int64_t>(&lod_mention_tensor, one_batch.rnn_mention_datas);
+  // Set inputs.
+  input_slots->assign({lod_word_tensor, lod_mention_tensor});
+  for (auto &tensor : *input_slots) {
+    tensor.dtype = PaddleDType::INT64;
+  }
+}
+
+// the first inference result
+const int chinese_ner_result_data[] = {30, 45, 41, 48, 17, 26,
+                                       48, 39, 38, 16, 25};
+
+void TestChineseNERPrediction() {
+  NativeConfig config;
+  config.prog_file = FLAGS_infer_model + "/__model__";
+  config.param_file = FLAGS_infer_model + "/param";
+  config.use_gpu = false;
+  config.device = 0;
+  config.specify_input_name = true;
+
+  auto predictor =
+      CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
+  std::vector<PaddleTensor> input_slots;
+  DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
+  // Prepare inputs.
+  PrepareInputs(&input_slots, &data, FLAGS_batch_size);
+  std::vector<PaddleTensor> outputs;
+
+  Timer timer;
+  timer.tic();
+  for (int i = 0; i < FLAGS_repeat; i++) {
+    predictor->Run(input_slots, &outputs);
+  }
+  LOG(INFO) << "===========profile result===========";
+  LOG(INFO) << "batch_size: " << FLAGS_batch_size
+            << ", repeat: " << FLAGS_repeat
+            << ", latency: " << timer.toc() / FLAGS_repeat << "ms";
+  LOG(INFO) << "=====================================";
+
+  PADDLE_ENFORCE(outputs.size(), 1UL);
+  auto &out = outputs[0];
+  size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
+                                [](int a, int b) { return a * b; });
+  PADDLE_ENFORCE_GT(size, 0);
+  int64_t *result = static_cast<int64_t *>(out.data.data());
+  for (size_t i = 0; i < std::min(11UL, size); i++) {
+    PADDLE_ENFORCE(result[i], chinese_ner_result_data[i]);
+  }
+}
+
+// Directly infer with the original model.
+TEST(Analyzer, Chinese_ner) { TestChineseNERPrediction(); }
+
+}  // namespace inference
+}  // namespace paddle
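
For reference, Load() in the tester above splits each line of data.txt on ';' and consumes only fields 1 and 3, which hold space-separated int64 word ids and mention ids; everything else on the line is ignored. The snippet below is a self-contained sketch of that parsing using only the standard library, with a hypothetical sample record; the real test relies on the split and split_to_int64 helpers from paddle/fluid/inference/api/helper.h.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Split a string on a single-character separator.
static std::vector<std::string> SplitOn(const std::string &s, char sep) {
  std::vector<std::string> out;
  std::stringstream ss(s);
  std::string field;
  while (std::getline(ss, field, sep)) out.push_back(field);
  return out;
}

// Parse a whitespace-separated list of int64 values.
static std::vector<int64_t> ToInt64s(const std::string &s) {
  std::vector<int64_t> out;
  std::stringstream ss(s);
  int64_t v;
  while (ss >> v) out.push_back(v);
  return out;
}

int main() {
  // Hypothetical record: field 1 holds word ids, field 3 holds mention ids.
  std::string line = "O O O;12 7 4051 389;O O O;0 0 1 0";
  auto fields = SplitOn(line, ';');
  auto word_data = ToInt64s(fields[1]);     // analogous to split_to_int64(data[1], ...)
  auto mention_data = ToInt64s(fields[3]);  // analogous to split_to_int64(data[3], ...)
  std::cout << word_data.size() << " word ids, " << mention_data.size()
            << " mention ids\n";
  return 0;
}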

paddle/fluid/inference/api/api_impl.cc

Lines changed: 69 additions & 63 deletions
@@ -62,14 +62,14 @@ void NativePaddlePredictor::PrepareFeedFetch() {
   for (auto *op : inference_program_->Block(0).AllOps()) {
     if (op->Type() == "feed") {
       int idx = boost::get<int>(op->GetAttr("col"));
-      if (feeds_.size() <= static_cast<size_t>(idx)) {
+      if (feeds_.size() <= (size_t)idx) {
         feeds_.resize(idx + 1);
       }
       feeds_[idx] = op;
       feed_names_[op->Output("Out")[0]] = idx;
     } else if (op->Type() == "fetch") {
       int idx = boost::get<int>(op->GetAttr("col"));
-      if (fetchs_.size() <= idx) {
+      if (fetchs_.size() <= (size_t)idx) {
         fetchs_.resize(idx + 1);
       }
       fetchs_[idx] = op;
@@ -222,77 +222,83 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
   }
   return true;
 }
+template <typename T>
+void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch,
+                                        PaddleTensor *output) {
+  std::vector<int> shape;
+  auto dims_i = fetch.dims();
+  auto lod = fetch.lod();
+  const T *output_ptr = fetch.data<T>();
+  auto num = fetch.numel();
+  std::vector<T> data;
+  if (0 == lod.size()) {
+    std::copy(output_ptr, output_ptr + num, std::back_inserter(data));
+    for (int j = 0; j < dims_i.size(); ++j) {
+      shape.push_back(dims_i[j]);
+    }
+  } else {
+    // for batch detection
+    // image[0] -> output[0] shape {145, 6}
+    // image[1] -> output[1] shape {176, 6}
+    // then,
+    // the batch output shape {321, 6}
+    // the lod {{0, 145, 321}}
+    // so we should append output[0] to {176, 6}
+    size_t max_dim = 0;
+    for (size_t j = 1; j < lod[0].size(); j++) {
+      max_dim = std::max(max_dim, lod[0][j] - lod[0][j - 1]);
+    }
+    size_t common_dim = lod[0].back() == 0 ? 0 : num / lod[0].back();
+    if (max_dim > 0) {
+      data.resize((lod[0].size() - 1) * max_dim * common_dim, 0);
+    }
+    for (size_t j = 1; j < lod[0].size(); j++) {
+      size_t start = lod[0][j - 1] * common_dim;
+      size_t end = lod[0][j] * common_dim;
+      if (end > start) {
+        std::copy(output_ptr + start, output_ptr + end,
+                  data.begin() + (j - 1) * max_dim * common_dim);
+      }
+    }
+    shape.push_back(lod[0].size() - 1);
+    shape.push_back(max_dim);
+    for (int j = 1; j < dims_i.size(); ++j) {
+      shape.push_back(dims_i[j]);
+    }
+  }
+
+  output->shape = shape;
+  auto &buffer = output->data;
+  if (buffer.empty() || buffer.length() < sizeof(T) * data.size()) {
+    buffer.Resize(sizeof(T) * data.size());
+  }
+  std::memcpy(buffer.data(), data.data(), buffer.length());
+  // copy LoD
+  for (const auto &level : fetch.lod()) {
+    output->lod.emplace_back(level);
+  }
+}
 
 bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
                                      framework::Scope *scope) {
   VLOG(3) << "Predictor::get_fetch";
   outputs->resize(fetchs_.size());
   for (size_t i = 0; i < fetchs_.size(); ++i) {
     int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
-    PADDLE_ENFORCE(idx == i);
-    framework::LoDTensor &output =
+    PADDLE_ENFORCE((size_t)idx == i);
+    framework::LoDTensor &fetch =
         framework::GetFetchVariable(*scope, "fetch", idx);
-    // TODO(panyx0718): Support fetch of other types.
-    if (output.type() != typeid(float)) {
-      LOG(ERROR) << "only support fetching float now.";
-      return false;
-    }
-
-    std::vector<int> shape;
-    auto dims_i = output.dims();
-    auto lod = output.lod();
-    const float *output_ptr = output.data<float>();
-    // const int64_t* output_ptr = fetchs[i].data<int64_t>();
-    auto num = output.numel();
-    std::vector<float> data;
-    if (0 == lod.size()) {
-      std::copy(output_ptr, output_ptr + num, std::back_inserter(data));
-      for (int j = 0; j < dims_i.size(); ++j) {
-        shape.push_back(dims_i[j]);
-      }
+    auto type = fetch.type();
+    auto output = &(outputs->at(i));
+    if (type == typeid(float)) {
+      GetFetchOne<float>(fetch, output);
+      output->dtype = PaddleDType::FLOAT32;
+    } else if (type == typeid(int64_t)) {
+      GetFetchOne<int64_t>(fetch, output);
+      output->dtype = PaddleDType::INT64;
     } else {
-      // for batch detection
-      // image[0] -> output[0] shape {145, 6}
-      // image[1] -> output[1] shape {176, 6}
-      // then,
-      // the batch output shape {321, 6}
-      // the lod {{0, 145, 321}}
-      // so we should append output[0] to {176, 6}
-      size_t max_dim = 0;
-      for (size_t j = 1; j < lod[0].size(); j++) {
-        max_dim = std::max(max_dim, lod[0][j] - lod[0][j - 1]);
-      }
-      size_t common_dim = lod[0].back() == 0 ? 0 : num / lod[0].back();
-      if (max_dim > 0) {
-        data.resize((lod[0].size() - 1) * max_dim * common_dim, 0);
-      }
-      for (size_t j = 1; j < lod[0].size(); j++) {
-        size_t start = lod[0][j - 1] * common_dim;
-        size_t end = lod[0][j] * common_dim;
-        if (end > start) {
-          std::copy(output_ptr + start, output_ptr + end,
-                    data.begin() + (j - 1) * max_dim * common_dim);
-        }
-      }
-      shape.push_back(lod[0].size() - 1);
-      shape.push_back(max_dim);
-      for (int j = 1; j < dims_i.size(); ++j) {
-        shape.push_back(dims_i[j]);
-      }
-    }
-
-    outputs->at(i).shape = shape;
-    auto &buffer = outputs->at(i).data;
-    if (buffer.empty() || buffer.length() < sizeof(float) * data.size()) {
-      buffer.Resize(sizeof(float) * data.size());
-    }
-    std::memcpy(buffer.data(), data.data(), buffer.length());
-    // copy LoD
-    for (const auto &level : output.lod()) {
-      outputs->at(i).lod.emplace_back(level);
+      LOG(ERROR) << "unknown type, only support float32 and int64 now.";
     }
-    outputs->at(i).dtype = PaddleDType::FLOAT32;
-    // TODO(panyx0718): support other types? fill tensor name? avoid a copy.
   }
   return true;
 }
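
The batch-detection branch that moved into GetFetchOne repacks a LoD-flattened fetch into a dense, zero-padded buffer, exactly as the in-code comment describes: an output of shape {321, 6} with lod {{0, 145, 321}} becomes {2, 176, 6}, with the shorter sequence padded by zeros. The standalone sketch below reproduces that computation; the values are illustrative, not taken from a real run.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<size_t> lod = {0, 145, 321};  // sequence offsets from the comment
  size_t num = 321 * 6;                     // total number of elements
  std::vector<float> flat(num, 1.0f);       // stand-in for the flat fetch data

  // Longest sequence length across the batch.
  size_t max_dim = 0;
  for (size_t j = 1; j < lod.size(); j++) {
    max_dim = std::max(max_dim, lod[j] - lod[j - 1]);
  }
  // Number of values per time step (here 6).
  size_t common_dim = lod.back() == 0 ? 0 : num / lod.back();

  // Zero-initialized padded buffer of shape {batch, max_dim, common_dim}.
  std::vector<float> padded((lod.size() - 1) * max_dim * common_dim, 0.0f);
  for (size_t j = 1; j < lod.size(); j++) {
    size_t start = lod[j - 1] * common_dim;
    size_t end = lod[j] * common_dim;
    std::copy(flat.begin() + start, flat.begin() + end,
              padded.begin() + (j - 1) * max_dim * common_dim);
  }

  // Resulting shape: {lod.size() - 1, max_dim, common_dim} = {2, 176, 6}.
  std::cout << "padded shape: {" << lod.size() - 1 << ", " << max_dim << ", "
            << common_dim << "}\n";
  return 0;
}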

paddle/fluid/inference/api/api_impl.h

Lines changed: 3 additions & 1 deletion
@@ -51,7 +51,9 @@ class NativePaddlePredictor : public PaddlePredictor {
                 framework::Scope *scope);
   bool GetFetch(std::vector<PaddleTensor> *output_data,
                 framework::Scope *scope);
-
+  template <typename T>
+  void GetFetchOne(const framework::LoDTensor &fetchs,
+                   PaddleTensor *output_data);
   void PrepareFeedFetch();
 
   NativeConfig config_;
