Commit 6de0a18

Refine/text classification support data (#13256)
1 parent 11b2288 commit 6de0a18

2 files changed, 54 insertions(+), 20 deletions(-)

paddle/fluid/inference/analysis/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
@@ -100,12 +100,17 @@ inference_analysis_test(test_analyzer_lac SRCS analyzer_lac_tester.cc
 
 
 set(TEXT_CLASSIFICATION_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/text-classification-Senta.tar.gz")
+set(TEXT_CLASSIFICATION_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/text_classification_data.txt.tar.gz")
 set(TEXT_CLASSIFICATION_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/text_classification" CACHE PATH "Text Classification model and data root." FORCE)
 
 if (NOT EXISTS ${TEXT_CLASSIFICATION_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
   inference_download_and_uncompress(${TEXT_CLASSIFICATION_INSTALL_DIR} ${TEXT_CLASSIFICATION_MODEL_URL} "text-classification-Senta.tar.gz")
+  inference_download_and_uncompress(${TEXT_CLASSIFICATION_INSTALL_DIR} ${TEXT_CLASSIFICATION_DATA_URL} "text_classification_data.txt.tar.gz")
 endif()
 
 inference_analysis_test(test_text_classification SRCS analyzer_text_classification_tester.cc
   EXTRA_DEPS paddle_inference_api paddle_fluid_api analysis_predictor
-  ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta)
+  ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta
+       --infer_data=${TEXT_CLASSIFICATION_INSTALL_DIR}/data.txt
+       --topn=1  # Just run top 1 batch.
+  )
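
The tester consumes the downloaded data.txt line by line, each line holding the space-separated int64 word ids of one input sequence (see the C++ diff below). As a rough illustration only, a toy file in that format could be generated like this; the ids are made up for the example:

#include <fstream>

// Write a toy dataset in the layout the tester reads: one sequence per
// line, word ids separated by single spaces. The ids are arbitrary.
int main() {
  std::ofstream out("data.txt");
  out << "0 1 2\n";       // a 3-word sequence
  out << "5 9 4 87 2\n";  // a 5-word sequence
  return 0;
}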

paddle/fluid/inference/analysis/analyzer_text_classification_tester.cc

Lines changed: 48 additions & 19 deletions
@@ -16,8 +16,10 @@
 #include <gflags/gflags.h>
 #include <glog/logging.h>  // use glog instead of PADDLE_ENFORCE to avoid importing other paddle header files.
 #include <gtest/gtest.h>
+#include <fstream>
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/inference/analysis/ut_helper.h"
+#include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 #include "paddle/fluid/inference/api/paddle_inference_pass.h"
 #include "paddle/fluid/inference/api/timer.h"
@@ -26,6 +28,7 @@ DEFINE_string(infer_model, "", "Directory of the inference model.");
 DEFINE_string(infer_data, "", "Path of the dataset.");
 DEFINE_int32(batch_size, 1, "batch size.");
 DEFINE_int32(repeat, 1, "How many times to repeat run.");
+DEFINE_int32(topn, -1, "Run top n batches of data to save time");
 
 namespace paddle {
 
@@ -45,41 +48,67 @@ void PrintTime(const double latency, const int bs, const int repeat) {
   LOG(INFO) << "=====================================";
 }
 
-void Main(int batch_size) {
-  // Three sequence inputs.
-  std::vector<PaddleTensor> input_slots(1);
-  // one batch starts
-  // data --
-  int64_t data0[] = {0, 1, 2};
-  for (auto &input : input_slots) {
-    input.data.Reset(data0, sizeof(data0));
-    input.shape = std::vector<int>({3, 1});
-    // dtype --
-    input.dtype = PaddleDType::INT64;
-    // LoD --
-    input.lod = std::vector<std::vector<size_t>>({{0, 3}});
+struct DataReader {
+  explicit DataReader(const std::string &path) : file(new std::ifstream(path)) {}
+
+  bool NextBatch(PaddleTensor *tensor, int batch_size) {
+    PADDLE_ENFORCE_EQ(batch_size, 1);
+    std::string line;
+    tensor->lod.clear();
+    tensor->lod.emplace_back(std::vector<size_t>({0}));
+    std::vector<int64_t> data;
+
+    for (int i = 0; i < batch_size; i++) {
+      if (!std::getline(*file, line)) return false;
+      inference::split_to_int64(line, ' ', &data);
+    }
+    tensor->lod.front().push_back(data.size());
+
+    tensor->data.Resize(data.size() * sizeof(int64_t));
+    memcpy(tensor->data.data(), data.data(), data.size() * sizeof(int64_t));
+    tensor->shape.clear();
+    tensor->shape.push_back(data.size());
+    tensor->shape.push_back(1);
+    return true;
   }
 
+  std::unique_ptr<std::ifstream> file;
+};
+
+void Main(int batch_size) {
   // shape --
   // Create Predictor --
   AnalysisConfig config;
   config.model_dir = FLAGS_infer_model;
   config.use_gpu = false;
   config.enable_ir_optim = true;
-  config.ir_passes.push_back("fc_lstm_fuse_pass");
   auto predictor =
       CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
           config);
 
+  std::vector<PaddleTensor> input_slots(1);
+  // one batch starts
+  // data --
+  auto &input = input_slots[0];
+  input.dtype = PaddleDType::INT64;
+
   inference::Timer timer;
   double sum = 0;
   std::vector<PaddleTensor> output_slots;
-  for (int i = 0; i < FLAGS_repeat; i++) {
-    timer.tic();
-    CHECK(predictor->Run(input_slots, &output_slots));
-    sum += timer.toc();
+
+  int num_batches = 0;
+  for (int t = 0; t < FLAGS_repeat; t++) {
+    DataReader reader(FLAGS_infer_data);
+    while (reader.NextBatch(&input, FLAGS_batch_size)) {
+      if (FLAGS_topn > 0 && num_batches >= FLAGS_topn) break;  // cap total batches at topn
+      timer.tic();
+      CHECK(predictor->Run(input_slots, &output_slots));
+      sum += timer.toc();
+      ++num_batches;
+    }
   }
-  PrintTime(sum, batch_size, FLAGS_repeat);
+
+  PrintTime(sum, batch_size, num_batches);
 
   // Get output
   LOG(INFO) << "get outputs " << output_slots.size();
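
For readers skimming the diff, here is a minimal standalone sketch of the batching logic that DataReader::NextBatch implements; the FakeTensor struct and SplitToInt64 helper are illustrative stand-ins written for this note, not Paddle APIs:

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Illustrative stand-in for PaddleTensor: only the fields NextBatch touches.
struct FakeTensor {
  std::vector<int> shape;
  std::vector<std::vector<size_t>> lod;
  std::vector<int64_t> data;
};

// Stand-in for inference::split_to_int64: parse space-separated int64 ids.
void SplitToInt64(const std::string &line, std::vector<int64_t> *out) {
  std::istringstream ss(line);
  int64_t id;
  while (ss >> id) out->push_back(id);
}

// Mirrors DataReader::NextBatch for batch_size == 1: one line, one sequence.
bool NextBatch(std::istream &file, FakeTensor *tensor) {
  std::string line;
  if (!std::getline(file, line)) return false;  // dataset exhausted
  std::vector<int64_t> ids;
  SplitToInt64(line, &ids);
  tensor->data = ids;
  tensor->lod = {{0, ids.size()}};                     // one sequence spanning all ids
  tensor->shape = {static_cast<int>(ids.size()), 1};   // a column of word ids
  return true;
}

int main() {
  std::istringstream dataset("0 1 2\n3 4 5 6\n");
  FakeTensor t;
  while (NextBatch(dataset, &t)) {
    std::cout << "sequence length: " << t.shape[0] << "\n";
  }
  return 0;
}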

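The interaction between --repeat and --topn can also be shown in isolation. A sketch of that control flow, assuming a hypothetical RunOneBatch stands in for the timed predictor->Run call:

#include <iostream>

// Hypothetical stand-in for one timed predictor->Run(...) call.
void RunOneBatch(int batch_id) { std::cout << "ran batch " << batch_id << "\n"; }

int main() {
  const int repeat = 3;            // FLAGS_repeat
  const int topn = 1;              // FLAGS_topn; a value <= 0 disables the cap
  const int batches_per_pass = 4;  // pretend each pass over the data yields 4 batches

  int num_batches = 0;
  for (int t = 0; t < repeat; t++) {             // one pass over the data per repeat
    for (int b = 0; b < batches_per_pass; b++) {
      if (topn > 0 && num_batches >= topn) break;  // cap total batches at topn
      RunOneBatch(num_batches);
      ++num_batches;
    }
  }
  // With topn == 1, exactly one batch runs across all repeats.
  return 0;
}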