Commit 5b24002

Merge pull request #16399 from sfraczek/sfraczek/analyzer_int8_resnet50_test
create test for quantized resnet50
2 parents 278deba + 8ece7a9 commit 5b24002

4 files changed: +268 −13 lines changed
paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 28 additions & 0 deletions
@@ -23,6 +23,12 @@ function(inference_analysis_api_test target install_dir filename)
       ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt)
 endfunction()
 
+function(inference_analysis_api_int8_test target model_dir data_dir filename)
+  inference_analysis_test(${target} SRCS ${filename}
+      EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark
+      ARGS --infer_model=${model_dir}/model --infer_data=${data_dir}/data.bin --batch_size=100)
+endfunction()
+
 function(inference_analysis_api_test_with_fake_data target install_dir filename model_name)
   download_model(${install_dir} ${model_name})
   inference_analysis_test(${target} SRCS ${filename}

@@ -138,6 +144,28 @@ inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
 inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv
     "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz" SERIAL)
 
+# int8 image classification tests
+if(WITH_MKLDNN)
+  set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8")
+  if (NOT EXISTS ${INT8_DATA_DIR})
+    inference_download_and_uncompress(${INT8_DATA_DIR} "https://paddle-inference-dist.bj.bcebos.com/int8" "imagenet_val_100.tar.gz")
+  endif()
+
+  # resnet50 int8
+  set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50")
+  if (NOT EXISTS ${INT8_RESNET50_MODEL_DIR})
+    inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} "https://paddle-inference-dist.bj.bcebos.com/int8" "resnet50_int8_model.tar.gz")
+  endif()
+  inference_analysis_api_int8_test(test_analyzer_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL)
+
+  # mobilenet int8
+  set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenet")
+  if (NOT EXISTS ${INT8_MOBILENET_MODEL_DIR})
+    inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} "https://paddle-inference-dist.bj.bcebos.com/int8" "mobilenetv1_int8_model.tar.gz")
+  endif()
+  inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL)
+endif()
+
 # bert, max_len=20, embedding_dim=128
 set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert_emb128")
 download_model_and_data(${BERT_INSTALL_DIR} "bert_emb128_model.tar.gz" "bert_data_len20.txt.tar.gz")
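For orientation, here is a minimal sketch of how the ARGS above reach the tester binary. The flag names mirror those consumed by tester_helper.h and the new tester, but the definitions below are illustrative stand-ins, not the actual harness code:

#include <gflags/gflags.h>
#include <iostream>

// Hypothetical stand-ins for the flags the real testers define elsewhere.
DEFINE_string(infer_model, "", "Directory containing the inference model.");
DEFINE_string(infer_data, "", "Path to the binary input data file.");
DEFINE_int32(batch_size, 1, "Number of images per prediction batch.");

int main(int argc, char **argv) {
  // Invoked as, e.g.:
  //   ./tester --infer_model=.../model --infer_data=.../data.bin --batch_size=100
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  std::cout << "model: " << FLAGS_infer_model << ", data: " << FLAGS_infer_data
            << ", batch: " << FLAGS_batch_size << "\n";
  return 0;
}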

paddle/fluid/inference/tests/api/analyzer_bert_tester.cc

Lines changed: 0 additions & 13 deletions
@@ -53,19 +53,6 @@ void Split(const std::string &line, char sep, std::vector<T> *v) {
   }
 }
 
-template <typename T>
-constexpr paddle::PaddleDType GetPaddleDType();
-
-template <>
-constexpr paddle::PaddleDType GetPaddleDType<int64_t>() {
-  return paddle::PaddleDType::INT64;
-}
-
-template <>
-constexpr paddle::PaddleDType GetPaddleDType<float>() {
-  return paddle::PaddleDType::FLOAT32;
-}
-
 // Parse tensor from string
 template <typename T>
 bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) {
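The deleted templates are not gone: the same GetPaddleDType specializations reappear in tester_helper.h (see the last file below), so both this tester and the new INT8 tester can share them. A brief sketch of the idiom they enable; MakeTensor is a hypothetical helper for illustration, not part of the patch:

#include <string>
#include <vector>
// Assumes tester_helper.h (with the relocated GetPaddleDType) is on the include path.
#include "paddle/fluid/inference/tests/api/tester_helper.h"

// Derive a tensor's dtype from the element type at compile time
// instead of hard-coding PaddleDType values at each call site.
template <typename T>
paddle::PaddleTensor MakeTensor(const std::string &name,
                                const std::vector<int> &shape) {
  paddle::PaddleTensor tensor;
  tensor.name = name;
  tensor.shape = shape;
  tensor.dtype = paddle::inference::GetPaddleDType<T>();  // FLOAT32 for float, INT64 for int64_t
  return tensor;
}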
paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc

Lines changed: 189 additions & 0 deletions

@@ -0,0 +1,189 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <fstream>
+#include <iostream>
+#include "paddle/fluid/inference/api/paddle_analysis_config.h"
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
+
+DEFINE_int32(iterations, 0, "Number of iterations");
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+void SetConfig(AnalysisConfig *cfg) {
+  cfg->SetModel(FLAGS_infer_model);
+  cfg->SetProgFile("__model__");
+  cfg->DisableGpu();
+  cfg->SwitchIrOptim();
+  cfg->SwitchSpecifyInputNames(false);
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+
+  cfg->EnableMKLDNN();
+}
+
+template <typename T>
+class TensorReader {
+ public:
+  TensorReader(std::ifstream &file, size_t beginning_offset,
+               std::vector<int> shape, std::string name)
+      : file_(file), position(beginning_offset), shape_(shape), name_(name) {
+    numel =
+        std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<T>());
+  }
+
+  PaddleTensor NextBatch() {
+    PaddleTensor tensor;
+    tensor.name = name_;
+    tensor.shape = shape_;
+    tensor.dtype = GetPaddleDType<T>();
+    tensor.data.Resize(numel * sizeof(T));
+
+    file_.seekg(position);
+    file_.read(static_cast<char *>(tensor.data.data()), numel * sizeof(T));
+    position = file_.tellg();
+
+    if (file_.eof()) LOG(ERROR) << name_ << ": reached end of stream";
+    if (file_.fail())
+      throw std::runtime_error(name_ + ": failed reading file.");
+
+    return tensor;
+  }
+
+ protected:
+  std::ifstream &file_;
+  size_t position;
+  std::vector<int> shape_;
+  std::string name_;
+  size_t numel;
+};
+
+std::shared_ptr<std::vector<PaddleTensor>> GetWarmupData(
+    const std::vector<std::vector<PaddleTensor>> &test_data, int num_images) {
+  int test_data_batch_size = test_data[0][0].shape[0];
+  CHECK_LE(static_cast<size_t>(num_images),
+           test_data.size() * test_data_batch_size);
+
+  PaddleTensor images;
+  images.name = "input";
+  images.shape = {num_images, 3, 224, 224};
+  images.dtype = PaddleDType::FLOAT32;
+  images.data.Resize(sizeof(float) * num_images * 3 * 224 * 224);
+
+  PaddleTensor labels;
+  labels.name = "labels";
+  labels.shape = {num_images, 1};
+  labels.dtype = PaddleDType::INT64;
+  labels.data.Resize(sizeof(int64_t) * num_images);
+
+  for (int i = 0; i < num_images; i++) {
+    auto batch = i / test_data_batch_size;
+    auto element_in_batch = i % test_data_batch_size;
+    std::copy_n(static_cast<float *>(test_data[batch][0].data.data()) +
+                    element_in_batch * 3 * 224 * 224,
+                3 * 224 * 224,
+                static_cast<float *>(images.data.data()) + i * 3 * 224 * 224);
+
+    std::copy_n(static_cast<int64_t *>(test_data[batch][1].data.data()) +
+                    element_in_batch,
+                1, static_cast<int64_t *>(labels.data.data()) + i);
+  }
+
+  auto warmup_data = std::make_shared<std::vector<PaddleTensor>>(2);
+  (*warmup_data)[0] = std::move(images);
+  (*warmup_data)[1] = std::move(labels);
+  return warmup_data;
+}
+
+void SetInput(std::vector<std::vector<PaddleTensor>> *inputs,
+              int32_t batch_size = FLAGS_batch_size) {
+  std::ifstream file(FLAGS_infer_data, std::ios::binary);
+  if (!file) {
+    FAIL() << "Couldn't open file: " << FLAGS_infer_data;
+  }
+
+  int64_t total_images{0};
+  file.read(reinterpret_cast<char *>(&total_images), sizeof(total_images));
+  LOG(INFO) << "Total images in file: " << total_images;
+
+  std::vector<int> image_batch_shape{batch_size, 3, 224, 224};
+  std::vector<int> label_batch_shape{batch_size, 1};
+  auto labels_offset_in_file =
+      static_cast<size_t>(file.tellg()) +
+      sizeof(float) * total_images *
+          std::accumulate(image_batch_shape.begin() + 1,
+                          image_batch_shape.end(), 1, std::multiplies<int>());
+
+  TensorReader<float> image_reader(file, 0, image_batch_shape, "input");
+  TensorReader<int64_t> label_reader(file, labels_offset_in_file,
+                                     label_batch_shape, "label");
+
+  auto iterations = total_images / batch_size;
+  if (FLAGS_iterations > 0 && FLAGS_iterations < iterations)
+    iterations = FLAGS_iterations;
+  for (auto i = 0; i < iterations; i++) {
+    auto images = image_reader.NextBatch();
+    auto labels = label_reader.NextBatch();
+    inputs->emplace_back(
+        std::vector<PaddleTensor>{std::move(images), std::move(labels)});
+  }
+}
+
+TEST(Analyzer_int8_resnet50, quantization) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  AnalysisConfig q_cfg;
+  SetConfig(&q_cfg);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all, 100);
+
+  std::shared_ptr<std::vector<PaddleTensor>> warmup_data =
+      GetWarmupData(input_slots_all, 100);
+
+  q_cfg.EnableMkldnnQuantizer();
+  q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
+  q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(100);
+
+  CompareQuantizedAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+      reinterpret_cast<const PaddlePredictor::Config *>(&q_cfg),
+      input_slots_all);
+}
+
+TEST(Analyzer_int8_resnet50, profile) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+
+  std::shared_ptr<std::vector<PaddleTensor>> warmup_data =
+      GetWarmupData(input_slots_all, 100);
+
+  cfg.EnableMkldnnQuantizer();
+  cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
+  cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(100);
+
+  std::vector<PaddleTensor> outputs;
+
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
+}
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
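Taken together, SetInput and the label offset computation imply a simple layout for data.bin: an int64 image count, then every image as a contiguous 3x224x224 float block, then one int64 label per image. (Note that image_reader above is constructed with beginning_offset 0, which points at the count header rather than just past it, while the label offset does account for the header; readers of this diff may want to verify that detail against the data file.) Below is a hedged sketch of a writer producing that layout, with dummy values only; the real data ships in imagenet_val_100.tar.gz:

#include <cstdint>
#include <fstream>
#include <vector>

int main() {
  const int64_t num_images = 4;              // dummy count; the real file holds the validation set
  const size_t image_elems = 3 * 224 * 224;  // CHW, matching image_batch_shape

  std::ofstream out("data.bin", std::ios::binary);
  // Header: total image count, read first by SetInput.
  out.write(reinterpret_cast<const char *>(&num_images), sizeof(num_images));
  // All images as contiguous float blocks (what TensorReader<float> consumes).
  std::vector<float> image(image_elems, 0.5f);  // dummy pixel data
  for (int64_t i = 0; i < num_images; ++i)
    out.write(reinterpret_cast<const char *>(image.data()),
              image.size() * sizeof(float));
  // All labels, one int64 each (what TensorReader<int64_t> consumes).
  std::vector<int64_t> labels(num_images, 0);  // dummy labels
  out.write(reinterpret_cast<const char *>(labels.data()),
            labels.size() * sizeof(int64_t));
  return 0;
}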

paddle/fluid/inference/tests/api/tester_helper.h

Lines changed: 51 additions & 0 deletions
@@ -50,6 +50,7 @@ DEFINE_bool(use_analysis, true,
 DEFINE_bool(record_benchmark, false,
             "Record benchmark after profiling the model");
 DEFINE_double(accuracy, 1e-3, "Result Accuracy.");
+DEFINE_double(quantized_accuracy, 1e-2, "Result Quantized Accuracy.");
 DEFINE_bool(zero_copy, false, "Use ZeroCopy to speedup Feed/Fetch.");
 
 DECLARE_bool(profile);

@@ -58,6 +59,19 @@ DECLARE_int32(paddle_num_threads);
 namespace paddle {
 namespace inference {
 
+template <typename T>
+constexpr paddle::PaddleDType GetPaddleDType();
+
+template <>
+constexpr paddle::PaddleDType GetPaddleDType<int64_t>() {
+  return paddle::PaddleDType::INT64;
+}
+
+template <>
+constexpr paddle::PaddleDType GetPaddleDType<float>() {
+  return paddle::PaddleDType::FLOAT32;
+}
+
 void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
   const auto *analysis_config =
       reinterpret_cast<const AnalysisConfig *>(config);

@@ -392,6 +406,32 @@ void TestPrediction(const PaddlePredictor::Config *config,
   }
 }
 
+void CompareTopAccuracy(const std::vector<PaddleTensor> &output_slots1,
+                        const std::vector<PaddleTensor> &output_slots2) {
+  // first output: avg_cost
+  if (output_slots1.size() == 0 || output_slots2.size() == 0)
+    throw std::invalid_argument(
+        "CompareTopAccuracy: output_slots vector is empty.");
+  PADDLE_ENFORCE(output_slots1.size() >= 2UL);
+  PADDLE_ENFORCE(output_slots2.size() >= 2UL);
+
+  // second output: acc_top1
+  if (output_slots1[1].lod.size() > 0 || output_slots2[1].lod.size() > 0)
+    throw std::invalid_argument(
+        "CompareTopAccuracy: top1 accuracy output has nonempty LoD.");
+  if (output_slots1[1].dtype != paddle::PaddleDType::FLOAT32 ||
+      output_slots2[1].dtype != paddle::PaddleDType::FLOAT32)
+    throw std::invalid_argument(
+        "CompareTopAccuracy: top1 accuracy output is of a wrong type.");
+  float *top1_quantized = static_cast<float *>(output_slots1[1].data.data());
+  float *top1_reference = static_cast<float *>(output_slots2[1].data.data());
+  LOG(INFO) << "top1 INT8 accuracy: " << *top1_quantized;
+  LOG(INFO) << "top1 FP32 accuracy: " << *top1_reference;
+  LOG(INFO) << "Accepted accuracy drop threshold: " << FLAGS_quantized_accuracy;
+  CHECK_LE(std::abs(*top1_quantized - *top1_reference),
+           FLAGS_quantized_accuracy);
+}
+
 void CompareDeterministic(
     const PaddlePredictor::Config *config,
     const std::vector<std::vector<PaddleTensor>> &inputs) {
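To make the threshold concrete (with illustrative numbers only): at the default FLAGS_quantized_accuracy of 1e-2, an INT8 top-1 of 0.708 against an FP32 top-1 of 0.716 passes the check, since |0.708 - 0.716| = 0.008 <= 0.01, while a quantized top-1 of 0.700 would fail, since 0.016 > 0.01. The looser default compared with FLAGS_accuracy (1e-3) reflects that some accuracy drop is expected from INT8 quantization.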
@@ -421,6 +461,17 @@ void CompareNativeAndAnalysis(
   CompareResult(analysis_outputs, native_outputs);
 }
 
+void CompareQuantizedAndAnalysis(
+    const PaddlePredictor::Config *config,
+    const PaddlePredictor::Config *qconfig,
+    const std::vector<std::vector<PaddleTensor>> &inputs) {
+  PrintConfig(config, true);
+  std::vector<PaddleTensor> analysis_outputs, quantized_outputs;
+  TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
+  TestOneThreadPrediction(qconfig, inputs, &quantized_outputs, true);
+  CompareTopAccuracy(quantized_outputs, analysis_outputs);
+}
+
 void CompareNativeAndAnalysis(
     PaddlePredictor *native_pred, PaddlePredictor *analysis_pred,
     const std::vector<std::vector<PaddleTensor>> &inputs) {
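A note on the design: CompareQuantizedAndAnalysis runs both configs through TestOneThreadPrediction with use_analysis set to true, so the comparison is analysis-vs-analysis rather than analysis-vs-native. The two runs therefore differ only in what the caller applied to qconfig, namely EnableMkldnnQuantizer plus the warmup data, as in the new tester above, which isolates the effect of the INT8 quantization pass on top-1 accuracy.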
