Skip to content

Commit 1a99302

Browse files
committed
refine and reuse code
1 parent b7a64e8 commit 1a99302

File tree

3 files changed

+48
-79
lines changed

3 files changed

+48
-79
lines changed

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,6 @@ if (NOT EXISTS ${OCR_INSTALL_DIR} AND WITH_INFERENCE)
6868
message(STATUS "finish downloading ${filename}")
6969
endif()
7070
inference_analysis_test(test_analyzer_ocr SRCS analyzer_vis_tester.cc
71-
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor
71+
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
7272
ARGS --infer_model=${OCR_INSTALL_DIR}/model
7373
--infer_data=${OCR_INSTALL_DIR}/data.txt)

paddle/fluid/inference/tests/api/analyzer_vis_tester.cc

Lines changed: 20 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/inference/analysis/analyzer.h"
16-
#include <gflags/gflags.h>
17-
#include <glog/logging.h>
18-
#include <gtest/gtest.h>
1915
#include <fstream>
2016
#include <iostream>
21-
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
22-
#include "paddle/fluid/inference/analysis/ut_helper.h"
23-
#include "paddle/fluid/inference/api/analysis_predictor.h"
24-
#include "paddle/fluid/inference/api/helper.h"
25-
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
26-
27-
DEFINE_string(infer_model, "", "model path for LAC");
28-
DEFINE_string(infer_data, "", "data file for LAC");
29-
DEFINE_int32(batch_size, 1, "batch size.");
30-
DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
17+
#include "paddle/fluid/inference/tests/api/tester_helper.h"
3118

3219
namespace paddle {
3320
namespace inference {
@@ -105,69 +92,36 @@ void TestVisualPrediction(bool use_mkldnn) {
10592
VLOG(3) << "output.size " << outputs_slots.size();
10693

10794
// run native as reference
108-
NativeConfig config;
109-
config.param_file = FLAGS_infer_model + "/__params__";
110-
config.prog_file = FLAGS_infer_model + "/__model__";
111-
config.use_gpu = false;
112-
config.device = 0;
113-
// config.specify_input_name = true;
11495
auto ref_predictor =
115-
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
96+
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(cfg);
11697
std::vector<PaddleTensor> ref_outputs_slots;
11798
ref_predictor->Run({input}, &ref_outputs_slots);
118-
EXPECT_EQ(ref_outputs_slots.size(), outputs_slots.size());
119-
for (size_t i = 0; i < outputs_slots.size(); ++i) {
120-
auto &ref_out = ref_outputs_slots[i];
121-
auto &out = outputs_slots[i];
122-
size_t ref_size =
123-
std::accumulate(ref_out.shape.begin(), ref_out.shape.end(), 1,
124-
[](int a, int b) { return a * b; });
125-
size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
126-
[](int a, int b) { return a * b; });
127-
EXPECT_EQ(size, ref_size);
128-
EXPECT_EQ(out.dtype, ref_out.dtype);
129-
switch (out.dtype) {
130-
case PaddleDType::INT64: {
131-
int64_t *pdata = static_cast<int64_t *>(out.data.data());
132-
int64_t *pdata_ref = static_cast<int64_t *>(ref_out.data.data());
133-
for (size_t j = 0; j < size; ++j) {
134-
EXPECT_EQ(pdata_ref[j], pdata[j]);
135-
}
136-
break;
137-
}
138-
case PaddleDType::FLOAT32: {
139-
float *pdata = static_cast<float *>(out.data.data());
140-
float *pdata_ref = static_cast<float *>(ref_out.data.data());
141-
for (size_t j = 0; j < size; ++j) {
142-
EXPECT_NEAR(pdata_ref[j], pdata[j], 1e-3);
143-
}
144-
break;
145-
}
146-
}
147-
// print what are fused
148-
AnalysisPredictor *analysis_predictor =
149-
dynamic_cast<AnalysisPredictor *>(predictor.get());
150-
auto &fuse_statis = analysis_predictor->analysis_argument()
151-
.Get<std::unordered_map<std::string, int>>(
152-
framework::ir::kFuseStatisAttr);
153-
for (auto &item : fuse_statis) {
154-
LOG(INFO) << "fused " << item.first << " " << item.second;
155-
}
156-
int num_ops = 0;
157-
for (auto &node :
158-
analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
159-
if (node->IsFunction()) {
160-
++num_ops;
161-
}
99+
CompareResult(outputs_slots, ref_outputs_slots);
100+
// print what are fused
101+
AnalysisPredictor *analysis_predictor =
102+
dynamic_cast<AnalysisPredictor *>(predictor.get());
103+
auto &fuse_statis = analysis_predictor->analysis_argument()
104+
.Get<std::unordered_map<std::string, int>>(
105+
framework::ir::kFuseStatisAttr);
106+
for (auto &item : fuse_statis) {
107+
LOG(INFO) << "fused " << item.first << " " << item.second;
108+
}
109+
int num_ops = 0;
110+
for (auto &node :
111+
analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
112+
if (node->IsFunction()) {
113+
++num_ops;
162114
}
163-
LOG(INFO) << "has num ops: " << num_ops;
164115
}
116+
LOG(INFO) << "has num ops: " << num_ops;
165117
}
166118

167119
TEST(Analyzer_vis, analysis) { TestVisualPrediction(/*use_mkldnn*/ false); }
120+
#ifdef PADDLE_WITH_MKLDNN
168121
TEST(Analyzer_vis, analysis_mkldnn) {
169122
TestVisualPrediction(/*use_mkldnn*/ true);
170123
}
124+
#endif
171125

172126
} // namespace analysis
173127
} // namespace inference

paddle/fluid/inference/tests/api/tester_helper.h

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,37 @@ namespace paddle {
3737
namespace inference {
3838

3939
void CompareResult(const std::vector<PaddleTensor> &outputs,
40-
const std::vector<PaddleTensor> &base_outputs) {
41-
PADDLE_ENFORCE_GT(outputs.size(), 0);
42-
PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
40+
const std::vector<PaddleTensor> &ref_outputs) {
41+
EXPECT_GT(outputs.size(), 0);
42+
EXPECT_EQ(outputs.size(), ref_outputs.size());
4343
for (size_t i = 0; i < outputs.size(); i++) {
4444
auto &out = outputs[i];
45-
auto &base_out = base_outputs[i];
45+
auto &ref_out = ref_outputs[i];
4646
size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
4747
[](int a, int b) { return a * b; });
48-
size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(),
49-
1, [](int a, int b) { return a * b; });
50-
PADDLE_ENFORCE_EQ(size, size1);
51-
PADDLE_ENFORCE_GT(size, 0);
52-
float *data = static_cast<float *>(out.data.data());
53-
float *base_data = static_cast<float *>(base_out.data.data());
54-
for (size_t i = 0; i < size; i++) {
55-
EXPECT_NEAR(data[i], base_data[i], 1e-3);
48+
size_t ref_size =
49+
std::accumulate(ref_out.shape.begin(), ref_out.shape.end(), 1,
50+
[](int a, int b) { return a * b; });
51+
EXPECT_GT(size, 0);
52+
EXPECT_EQ(size, ref_size);
53+
EXPECT_EQ(out.dtype, ref_out.dtype);
54+
switch (out.dtype) {
55+
case PaddleDType::INT64: {
56+
int64_t *pdata = static_cast<int64_t *>(out.data.data());
57+
int64_t *pdata_ref = static_cast<int64_t *>(ref_out.data.data());
58+
for (size_t j = 0; j < size; ++j) {
59+
EXPECT_EQ(pdata_ref[j], pdata[j]);
60+
}
61+
break;
62+
}
63+
case PaddleDType::FLOAT32: {
64+
float *pdata = static_cast<float *>(out.data.data());
65+
float *pdata_ref = static_cast<float *>(ref_out.data.data());
66+
for (size_t j = 0; j < size; ++j) {
67+
EXPECT_NEAR(pdata_ref[j], pdata[j], 1e-3);
68+
}
69+
break;
70+
}
5671
}
5772
}
5873
}

0 commit comments

Comments
 (0)