@@ -29,25 +29,7 @@ void SetConfig(AnalysisConfig *cfg) {
29
29
}
30
30
31
31
void SetInput (std::vector<std::vector<PaddleTensor>> *inputs) {
32
- PADDLE_ENFORCE_EQ (FLAGS_test_all_data, 0 , " Only have single batch of data." );
33
-
34
- PaddleTensor input;
35
- // channel=3, height/width=318
36
- std::vector<int > shape ({FLAGS_batch_size, 3 , 318 , 318 });
37
- input.shape = shape;
38
- input.dtype = PaddleDType::FLOAT32;
39
-
40
- // fill input data, for profile easily, do not use random data here.
41
- size_t size = FLAGS_batch_size * 3 * 318 * 318 ;
42
- input.data .Resize (size * sizeof (float ));
43
- float *input_data = static_cast <float *>(input.data .data ());
44
- for (size_t i = 0 ; i < size; i++) {
45
- *(input_data + i) = static_cast <float >(i) / size;
46
- }
47
-
48
- std::vector<PaddleTensor> input_slots;
49
- input_slots.assign ({input});
50
- (*inputs).emplace_back (input_slots);
32
+ SetFakeImageInput (inputs, FLAGS_infer_model);
51
33
}
52
34
53
35
// Easy for profiling independently.
@@ -60,21 +42,14 @@ void profile(bool use_mkldnn = false) {
60
42
std::vector<std::vector<PaddleTensor>> input_slots_all;
61
43
SetInput (&input_slots_all);
62
44
TestPrediction (cfg, input_slots_all, &outputs, FLAGS_num_threads);
63
-
64
- if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
65
- PADDLE_ENFORCE_EQ (outputs.size (), 1UL );
66
- size_t size = GetSize (outputs[0 ]);
67
- // output is a 1000-dimension feature
68
- EXPECT_EQ (size, 1000 * FLAGS_batch_size);
69
- }
70
45
}
71
46
72
47
TEST (Analyzer_mobilenet, profile) { profile (); }
73
48
#ifdef PADDLE_WITH_MKLDNN
74
49
TEST (Analyzer_mobilenet, profile_mkldnn) { profile (true /* use_mkldnn */ ); }
75
50
#endif
76
51
77
- // Check the depthwise_conv status
52
+ // Check the depthwise_conv pass status
78
53
TEST (Analyzer_mobilenet, depthwise_conv_statis) {
79
54
AnalysisConfig cfg;
80
55
SetConfig (&cfg);
@@ -83,8 +58,7 @@ TEST(Analyzer_mobilenet, depthwise_conv_statis) {
83
58
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
84
59
auto fuse_statis = GetFuseStatis (
85
60
static_cast <AnalysisPredictor *>(predictor.get ()), &num_ops);
86
- ASSERT_TRUE (fuse_statis.count (" depthwise_conv_mkldnn_pass" ));
87
- EXPECT_EQ (fuse_statis.at (" depthwise_conv_mkldnn_pass" ), 13 );
61
+ LOG (INFO) << " num_ops: " << num_ops;
88
62
}
89
63
90
64
// Compare result of NativeConfig and AnalysisConfig
0 commit comments