@@ -12,22 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- #include " paddle/fluid/inference/analysis/analyzer.h"
16
- #include < gflags/gflags.h>
17
- #include < glog/logging.h>
18
- #include < gtest/gtest.h>
19
15
#include < fstream>
20
16
#include < iostream>
21
- #include " paddle/fluid/framework/ir/fuse_pass_base.h"
22
- #include " paddle/fluid/inference/analysis/ut_helper.h"
23
- #include " paddle/fluid/inference/api/analysis_predictor.h"
24
- #include " paddle/fluid/inference/api/helper.h"
25
- #include " paddle/fluid/inference/api/paddle_inference_pass.h"
26
-
27
- DEFINE_string (infer_model, " " , " model path for LAC" );
28
- DEFINE_string (infer_data, " " , " data file for LAC" );
29
- DEFINE_int32 (batch_size, 1 , " batch size." );
30
- DEFINE_int32 (repeat, 1 , " Running the inference program repeat times." );
17
+ #include " paddle/fluid/inference/tests/api/tester_helper.h"
31
18
32
19
namespace paddle {
33
20
namespace inference {
@@ -105,69 +92,36 @@ void TestVisualPrediction(bool use_mkldnn) {
105
92
VLOG (3 ) << " output.size " << outputs_slots.size ();
106
93
107
94
// run native as reference
108
- NativeConfig config;
109
- config.param_file = FLAGS_infer_model + " /__params__" ;
110
- config.prog_file = FLAGS_infer_model + " /__model__" ;
111
- config.use_gpu = false ;
112
- config.device = 0 ;
113
- // config.specify_input_name = true;
114
95
auto ref_predictor =
115
- CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative >(config );
96
+ CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative >(cfg );
116
97
std::vector<PaddleTensor> ref_outputs_slots;
117
98
ref_predictor->Run ({input}, &ref_outputs_slots);
118
- EXPECT_EQ (ref_outputs_slots.size (), outputs_slots.size ());
119
- for (size_t i = 0 ; i < outputs_slots.size (); ++i) {
120
- auto &ref_out = ref_outputs_slots[i];
121
- auto &out = outputs_slots[i];
122
- size_t ref_size =
123
- std::accumulate (ref_out.shape .begin (), ref_out.shape .end (), 1 ,
124
- [](int a, int b) { return a * b; });
125
- size_t size = std::accumulate (out.shape .begin (), out.shape .end (), 1 ,
126
- [](int a, int b) { return a * b; });
127
- EXPECT_EQ (size, ref_size);
128
- EXPECT_EQ (out.dtype , ref_out.dtype );
129
- switch (out.dtype ) {
130
- case PaddleDType::INT64: {
131
- int64_t *pdata = static_cast <int64_t *>(out.data .data ());
132
- int64_t *pdata_ref = static_cast <int64_t *>(ref_out.data .data ());
133
- for (size_t j = 0 ; j < size; ++j) {
134
- EXPECT_EQ (pdata_ref[j], pdata[j]);
135
- }
136
- break ;
137
- }
138
- case PaddleDType::FLOAT32: {
139
- float *pdata = static_cast <float *>(out.data .data ());
140
- float *pdata_ref = static_cast <float *>(ref_out.data .data ());
141
- for (size_t j = 0 ; j < size; ++j) {
142
- EXPECT_NEAR (pdata_ref[j], pdata[j], 1e-3 );
143
- }
144
- break ;
145
- }
146
- }
147
- // print what are fused
148
- AnalysisPredictor *analysis_predictor =
149
- dynamic_cast <AnalysisPredictor *>(predictor.get ());
150
- auto &fuse_statis = analysis_predictor->analysis_argument ()
151
- .Get <std::unordered_map<std::string, int >>(
152
- framework::ir::kFuseStatisAttr );
153
- for (auto &item : fuse_statis) {
154
- LOG (INFO) << " fused " << item.first << " " << item.second ;
155
- }
156
- int num_ops = 0 ;
157
- for (auto &node :
158
- analysis_predictor->analysis_argument ().main_dfg ->nodes .nodes ()) {
159
- if (node->IsFunction ()) {
160
- ++num_ops;
161
- }
99
+ CompareResult (outputs_slots, ref_outputs_slots);
100
+ // print what are fused
101
+ AnalysisPredictor *analysis_predictor =
102
+ dynamic_cast <AnalysisPredictor *>(predictor.get ());
103
+ auto &fuse_statis = analysis_predictor->analysis_argument ()
104
+ .Get <std::unordered_map<std::string, int >>(
105
+ framework::ir::kFuseStatisAttr );
106
+ for (auto &item : fuse_statis) {
107
+ LOG (INFO) << " fused " << item.first << " " << item.second ;
108
+ }
109
+ int num_ops = 0 ;
110
+ for (auto &node :
111
+ analysis_predictor->analysis_argument ().main_dfg ->nodes .nodes ()) {
112
+ if (node->IsFunction ()) {
113
+ ++num_ops;
162
114
}
163
- LOG (INFO) << " has num ops: " << num_ops;
164
115
}
116
+ LOG (INFO) << " has num ops: " << num_ops;
165
117
}
166
118
167
119
TEST (Analyzer_vis, analysis) { TestVisualPrediction (/* use_mkldnn*/ false ); }
120
+ #ifdef PADDLE_WITH_MKLDNN
168
121
TEST (Analyzer_vis, analysis_mkldnn) {
169
122
TestVisualPrediction (/* use_mkldnn*/ true );
170
123
}
124
+ #endif
171
125
172
126
} // namespace analysis
173
127
} // namespace inference
0 commit comments