Skip to content

Commit 353b5f0

Browse files
committed
refine analyzer_bert_test to pass the ci
test=develop
1 parent cc61893 commit 353b5f0

File tree

1 file changed

+47
-22
lines changed

1 file changed

+47
-22
lines changed

paddle/fluid/inference/tests/api/analyzer_bert_tester.cc

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include <gflags/gflags.h>
16-
#include <glog/logging.h>
17-
#include <chrono>
18-
#include <fstream>
19-
#include <numeric>
20-
#include <sstream>
21-
#include <string>
22-
#include <vector>
23-
#include "paddle/fluid/inference/api/paddle_inference_api.h"
24-
25-
DEFINE_int32(repeat, 1, "repeat");
15+
#include "paddle/fluid/inference/tests/api/tester_helper.h"
2616

2717
namespace paddle {
2818
namespace inference {
@@ -166,16 +156,17 @@ bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
166156

167157
std::ifstream fin(FLAGS_infer_data);
168158
std::string line;
159+
int sample = 0;
169160

170-
int lineno = 0;
161+
// The unit-test dataset only has 10 samples; each sample has 5 feeds.
171162
while (std::getline(fin, line)) {
172163
std::vector<paddle::PaddleTensor> feed_data;
173-
if (!ParseLine(line, &feed_data)) {
174-
LOG(ERROR) << "Parse line[" << lineno << "] error!";
175-
} else {
176-
inputs->push_back(std::move(feed_data));
177-
}
164+
ParseLine(line, &feed_data);
165+
inputs->push_back(std::move(feed_data));
166+
sample++;
167+
if (!FLAGS_test_all_data && sample == FLAGS_batch_size) break;
178168
}
169+
LOG(INFO) << "number of samples: " << sample;
179170

180171
return true;
181172
}
@@ -199,19 +190,53 @@ void profile(bool use_mkldnn = false) {
199190
inputs, &outputs, FLAGS_num_threads);
200191
}
201192

193+
TEST(Analyzer_bert, profile) { profile(); }
194+
#ifdef PADDLE_WITH_MKLDNN
195+
TEST(Analyzer_bert, profile_mkldnn) { profile(true); }
196+
#endif
197+
198+
// Check the fuse status
199+
TEST(Analyzer_bert, fuse_statis) {
200+
AnalysisConfig cfg;
201+
SetConfig(&cfg);
202+
int num_ops;
203+
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
204+
auto fuse_statis = GetFuseStatis(
205+
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
206+
LOG(INFO) << "num_ops: " << num_ops;
207+
}
208+
209+
// Compare result of NativeConfig and AnalysisConfig
202210
void compare(bool use_mkldnn = false) {
203-
AnalysisConfig config;
204-
SetConfig(&config);
211+
AnalysisConfig cfg;
212+
SetConfig(&cfg);
213+
if (use_mkldnn) {
214+
cfg.EnableMKLDNN();
215+
}
205216

206217
std::vector<std::vector<PaddleTensor>> inputs;
207218
LoadInputData(&inputs);
208219
CompareNativeAndAnalysis(
209-
reinterpret_cast<const PaddlePredictor::Config *>(&config), inputs);
220+
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), inputs);
210221
}
211222

212-
TEST(Analyzer_bert, profile) { profile(); }
223+
TEST(Analyzer_bert, compare) { compare(); }
213224
#ifdef PADDLE_WITH_MKLDNN
214-
TEST(Analyzer_bert, profile_mkldnn) { profile(true); }
225+
TEST(Analyzer_bert, compare_mkldnn) { compare(true /* use_mkldnn */); }
215226
#endif
227+
228+
// Compare Deterministic result
229+
// TODO(luotao): Since each unit-test on CI only has 10 minutes, cancel this to
230+
// decrease the CI time.
231+
// TEST(Analyzer_bert, compare_determine) {
232+
// AnalysisConfig cfg;
233+
// SetConfig(&cfg);
234+
//
235+
// std::vector<std::vector<PaddleTensor>> inputs;
236+
// LoadInputData(&inputs);
237+
// CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config
238+
// *>(&cfg),
239+
// inputs);
240+
// }
216241
} // namespace inference
217242
} // namespace paddle

0 commit comments

Comments
 (0)