Skip to content

Commit cc61893

Browse files
committed
Merge branch 'bert_test' of https://github.com/fc500110/Paddle into fc500110-bert_test
2 parents 07dc5a1 + 4a33a44 commit cc61893

File tree

2 files changed

+222
-0
lines changed

2 files changed

+222
-0
lines changed

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -115,6 +115,11 @@ if (NOT EXISTS ${MOBILENET_INSTALL_DIR})
115115
endif()
116116
inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc SERIAL)

# bert
# Download the BERT model and its text input data into the shared demo
# install dir, then register the analyzer test that exercises them.
set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert")
download_model_and_data(${BERT_INSTALL_DIR} "bert_model.tar.gz" "bert_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc)

# resnet50
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
    "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" SERIAL)
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <gflags/gflags.h>
#include <glog/logging.h>

#include <algorithm>
#include <chrono>
#include <fstream>
#include <functional>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/inference/api/paddle_inference_api.h"
24+
25+
// --repeat: requested repetition count for the benchmark run.
// NOTE(review): not referenced anywhere in this file — confirm whether the
// shared test harness consumes it, or whether it is dead.
DEFINE_int32(repeat, 1, "repeat");

namespace paddle {
namespace inference {

using paddle::PaddleTensor;
using paddle::contrib::AnalysisConfig;
32+
33+
// Reads a single value of type T out of the stream using operator>>.
template <typename T>
void GetValueFromStream(std::stringstream *ss, T *t) {
  *ss >> *t;
}

// Specialization for std::string: the whole buffered contents of the
// stream become the value (whitespace included).
template <>
void GetValueFromStream<std::string>(std::stringstream *ss, std::string *t) {
  t->assign(ss->str());
}

// Splits `line` on `sep`, parses each piece as a T, and appends to *v.
// A token is emitted for every separator (so "a::b" yields an empty middle
// token); a trailing empty token after the final separator is dropped.
template <typename T>
void Split(const std::string &line, char sep, std::vector<T> *v) {
  std::stringstream buf;
  T value;

  // Parse whatever is buffered, push it, and reset the stream for reuse.
  auto emit = [&]() {
    GetValueFromStream<T>(&buf, &value);
    v->push_back(std::move(value));
    buf.str({});
    buf.clear();
  };

  for (char ch : line) {
    if (ch == sep) {
      emit();
    } else {
      buf << ch;
    }
  }

  if (!buf.str().empty()) {
    emit();
  }
}
66+
67+
// Maps a C++ element type to the matching PaddleDType enum value.
// Only the types this tester actually feeds (int64_t ids, float bias) are
// specialized; instantiating any other T fails at link time because the
// primary template is declared but never defined.
template <typename T>
constexpr paddle::PaddleDType GetPaddleDType();

template <>
constexpr paddle::PaddleDType GetPaddleDType<int64_t>() {
  return paddle::PaddleDType::INT64;
}

template <>
constexpr paddle::PaddleDType GetPaddleDType<float>() {
  return paddle::PaddleDType::FLOAT32;
}
79+
80+
// Parse tensor from string
81+
template <typename T>
82+
bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) {
83+
std::vector<std::string> data;
84+
Split(field, ':', &data);
85+
if (data.size() < 2) return false;
86+
87+
std::string shape_str = data[0];
88+
89+
std::vector<int> shape;
90+
Split(shape_str, ' ', &shape);
91+
92+
std::string mat_str = data[1];
93+
94+
std::vector<T> mat;
95+
Split(mat_str, ' ', &mat);
96+
97+
tensor->shape = shape;
98+
auto size =
99+
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
100+
sizeof(T);
101+
tensor->data.Resize(size);
102+
std::copy(mat.begin(), mat.end(), static_cast<T *>(tensor->data.data()));
103+
tensor->dtype = GetPaddleDType<T>();
104+
105+
return true;
106+
}
107+
108+
// Parse input tensors from string
109+
bool ParseLine(const std::string &line,
110+
std::vector<paddle::PaddleTensor> *tensors) {
111+
std::vector<std::string> fields;
112+
Split(line, ';', &fields);
113+
114+
if (fields.size() < 5) return false;
115+
116+
tensors->clear();
117+
tensors->reserve(5);
118+
119+
int i = 0;
120+
// src_id
121+
paddle::PaddleTensor src_id;
122+
ParseTensor<int64_t>(fields[i++], &src_id);
123+
tensors->push_back(src_id);
124+
125+
// pos_id
126+
paddle::PaddleTensor pos_id;
127+
ParseTensor<int64_t>(fields[i++], &pos_id);
128+
tensors->push_back(pos_id);
129+
130+
// segment_id
131+
paddle::PaddleTensor segment_id;
132+
ParseTensor<int64_t>(fields[i++], &segment_id);
133+
tensors->push_back(segment_id);
134+
135+
// self_attention_bias
136+
paddle::PaddleTensor self_attention_bias;
137+
ParseTensor<float>(fields[i++], &self_attention_bias);
138+
tensors->push_back(self_attention_bias);
139+
140+
// next_segment_index
141+
paddle::PaddleTensor next_segment_index;
142+
ParseTensor<int64_t>(fields[i++], &next_segment_index);
143+
tensors->push_back(next_segment_index);
144+
145+
return true;
146+
}
147+
148+
// Print outputs to log
149+
void PrintOutputs(const std::vector<paddle::PaddleTensor> &outputs) {
150+
LOG(INFO) << "example_id\tcontradiction\tentailment\tneutral";
151+
152+
for (size_t i = 0; i < outputs.front().data.length(); i += 3) {
153+
LOG(INFO) << (i / 3) << "\t"
154+
<< static_cast<float *>(outputs.front().data.data())[i] << "\t"
155+
<< static_cast<float *>(outputs.front().data.data())[i + 1]
156+
<< "\t"
157+
<< static_cast<float *>(outputs.front().data.data())[i + 2];
158+
}
159+
}
160+
161+
// Reads FLAGS_infer_data line by line; each well-formed line becomes one
// batch of feed tensors appended to *inputs. Malformed lines are logged
// (with their 1-based line number) and skipped. Returns false only when no
// data path was given or the file cannot be opened.
bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
  if (FLAGS_infer_data.empty()) {
    LOG(ERROR) << "please set input data path";
    return false;
  }

  std::ifstream fin(FLAGS_infer_data);
  if (!fin.is_open()) {
    // The original silently returned success on an unreadable file,
    // leaving *inputs empty.
    LOG(ERROR) << "failed to open input data file: " << FLAGS_infer_data;
    return false;
  }

  std::string line;
  int lineno = 0;
  while (std::getline(fin, line)) {
    ++lineno;  // the original never incremented this, so errors all said line 0
    std::vector<paddle::PaddleTensor> feed_data;
    if (!ParseLine(line, &feed_data)) {
      LOG(ERROR) << "Parse line[" << lineno << "] error!";
    } else {
      inputs->push_back(std::move(feed_data));
    }
  }

  return true;
}
182+
183+
// Points the analysis config at the model directory given by --infer_model.
void SetConfig(contrib::AnalysisConfig *config) {
  config->SetModel(FLAGS_infer_model);
}
186+
187+
// Runs the analysis predictor over the loaded test data and reports timing
// via the shared TestPrediction harness (FLAGS_num_threads threads).
// When use_mkldnn is true, MKL-DNN kernels are enabled on the config first.
void profile(bool use_mkldnn = false) {
  contrib::AnalysisConfig config;
  SetConfig(&config);
  if (use_mkldnn) config.EnableMKLDNN();

  std::vector<std::vector<PaddleTensor>> input_batches;
  LoadInputData(&input_batches);

  std::vector<PaddleTensor> outputs;
  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&config),
                 input_batches, &outputs, FLAGS_num_threads);
}
201+
202+
// Checks that the analysis predictor matches the native predictor on the
// loaded test data. When use_mkldnn is true, MKL-DNN kernels are enabled —
// the original accepted this flag but never applied it, unlike profile().
void compare(bool use_mkldnn = false) {
  AnalysisConfig config;
  SetConfig(&config);
  if (use_mkldnn) {
    config.EnableMKLDNN();
  }

  std::vector<std::vector<PaddleTensor>> inputs;
  LoadInputData(&inputs);
  CompareNativeAndAnalysis(
      reinterpret_cast<const PaddlePredictor::Config *>(&config), inputs);
}
211+
212+
// Timing run with the default kernels.
TEST(Analyzer_bert, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN
// Same timing run with MKL-DNN kernels enabled.
TEST(Analyzer_bert, profile_mkldnn) { profile(true); }
#endif
}  // namespace inference
}  // namespace paddle

0 commit comments

Comments
 (0)