 // limitations under the License.

 #include "paddle/fluid/inference/api/api_anakin_engine.h"
+
+#ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
+#endif
+
+#include <mkl_service.h>
+#include <omp.h>
+#include <map>
+#include <string>
+#include <utility>
 #include <vector>

+#include "framework/core/net/net.h"
+#include "framework/operators/ops.h"
+#include "saber/funcs/timer.h"
+
 namespace paddle {

 template <typename Target>
 PaddleInferenceAnakinPredictor<Target>::PaddleInferenceAnakinPredictor(
     const AnakinConfig &config) {
   CHECK(Init(config));
 }
-
+template <>
+PaddleInferenceAnakinPredictor<anakin::X86>::PaddleInferenceAnakinPredictor(
+    const AnakinConfig &config) {
+  omp_set_dynamic(0);
+  omp_set_num_threads(1);
+  mkl_set_num_threads(1);
+  CHECK(Init(config));
+}
 template <typename Target>
 bool PaddleInferenceAnakinPredictor<Target>::Init(const AnakinConfig &config) {
   if (!(graph_.load(config.model_file))) {
-    LOG(FATAL) << "fail to load graph from " << config.model_file;
+    VLOG(3) << "fail to load graph from " << config.model_file;
     return false;
   }
   auto inputs = graph_.get_ins();
   for (auto &input_str : inputs) {
     graph_.ResetBatchSize(input_str, config.max_batch_size);
+    max_batch_size_ = config.max_batch_size;
   }
   // optimization for graph
   if (!(graph_.Optimize())) {
@@ -52,15 +73,15 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
     std::vector<PaddleTensor> *output_data, int batch_size) {
   for (const auto &input : inputs) {
     if (input.dtype != PaddleDType::FLOAT32) {
-      LOG(ERROR) << "Only support float type inputs. " << input.name
-                 << "'s type is not float";
+      VLOG(3) << "Only support float type inputs. " << input.name
+              << "'s type is not float";
       return false;
     }
     auto d_tensor_in_p = executor_p_->get_in(input.name);
-    auto net_shape = d_tensor_in_p->valid_shape();
+    auto net_shape = d_tensor_in_p->shape();
     if (net_shape.size() != input.shape.size()) {
-      LOG(ERROR) << "input " << input.name
-                 << "'s shape size should be equal to that of net";
+      VLOG(3) << "input " << input.name
+              << "'s shape size should be equal to that of net";
       return false;
     }
     int sum = 1;
@@ -79,21 +100,45 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
     }
     d_tensor_in_p->reshape(tmp_shape);

+    if (input.lod.size() > 0) {
+      if (input.lod.size() > 1) {
+        VLOG(3) << "input lod first dim should <=1, but you set "
+                << input.lod.size();
+        return false;
+      }
+      std::vector<int> offset(input.lod[0].begin(), input.lod[0].end());
+      d_tensor_in_p->set_seq_offset(offset);
+      VLOG(3) << "offset.size(): " << offset.size();
+      for (int i = 0; i < offset.size(); i++) {
+        VLOG(3) << offset[i];
+      }
+    }
+
     float *d_data_p = d_tensor_in_p->mutable_data();
-    if (cudaMemcpy(d_data_p, static_cast<float *>(input.data.data()),
-                   d_tensor_in_p->valid_size() * sizeof(float),
-                   cudaMemcpyHostToDevice) != 0) {
-      LOG(ERROR) << "copy data from CPU to GPU error";
-      return false;
+
+#ifdef PADDLE_WITH_CUDA
+    if (std::is_same<anakin::NV, Target>::value) {
+      if (cudaMemcpy(d_data_p, static_cast<float *>(input.data.data()),
+                     d_tensor_in_p->valid_size() * sizeof(float),
+                     cudaMemcpyHostToDevice) != 0) {
+        VLOG(3) << "copy data from CPU to GPU error";
+        return false;
+      }
+    }
+#endif
+    if (std::is_same<anakin::X86, Target>::value) {
+      memcpy(d_data_p, static_cast<float *>(input.data.data()),
+             d_tensor_in_p->valid_size() * sizeof(float));
     }
-    cudaStreamSynchronize(NULL);
   }
+#ifdef PADDLE_WITH_CUDA
   cudaDeviceSynchronize();
   executor_p_->prediction();
   cudaDeviceSynchronize();
+#endif

   if (output_data->empty()) {
-    LOG(ERROR) << "At least one output should be set with tensors' names.";
+    VLOG(3) << "At least one output should be set with tensors' names.";
     return false;
   }
   for (auto &output : *output_data) {
@@ -102,14 +147,22 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
     if (output.data.length() < tensor->valid_size() * sizeof(float)) {
       output.data.Resize(tensor->valid_size() * sizeof(float));
     }
-    // Copy data from GPU -> CPU
-    if (cudaMemcpy(output.data.data(), tensor->mutable_data(),
-                   tensor->valid_size() * sizeof(float),
-                   cudaMemcpyDeviceToHost) != 0) {
-      LOG(ERROR) << "copy data from GPU to CPU error";
-      return false;
+
+#if PADDLE_WITH_CUDA
+    if (std::is_same<anakin::NV, Target>::value) {
+      // Copy data from GPU -> CPU
+      if (cudaMemcpy(output.data.data(), tensor->mutable_data(),
+                     tensor->valid_size() * sizeof(float),
+                     cudaMemcpyDeviceToHost) != 0) {
+        VLOG(3) << "copy data from GPU to CPU error";
+        return false;
+      }
+    }
+#endif
+    if (std::is_same<anakin::X86, Target>::value) {
+      memcpy(output.data.data(), tensor->mutable_data(),
+             tensor->valid_size() * sizeof(float));
     }
-    cudaStreamSynchronize(NULL);
   }
   return true;
 }
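The copy logic above picks a path at run time with std::is_same on Target, so both branches must compile for every instantiation of Run(); that is why the cudaMemcpy branch also sits behind the PADDLE_WITH_CUDA guard. A minimal, self-contained sketch of the same pattern, using hypothetical NV/X86 tag types rather than the real Anakin ones:

#include <cstring>
#include <type_traits>
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif

struct NV {};   // hypothetical device tag
struct X86 {};  // hypothetical host tag

template <typename Target>
void CopyFromHost(float *dst, const float *src, size_t bytes) {
#ifdef PADDLE_WITH_CUDA
  if (std::is_same<NV, Target>::value) {
    // Device path: only compiled (and only reachable) in CUDA builds.
    cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice);
    return;
  }
#endif
  if (std::is_same<X86, Target>::value) {
    // Host path: plain memcpy, always available.
    std::memcpy(dst, src, bytes);
  }
}

With C++17 this dispatch could be written with if constexpr, but the run-time check plus preprocessor guard keeps the code buildable as C++11.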
@@ -132,7 +185,7 @@ PaddleInferenceAnakinPredictor<Target>::Clone() {
   auto anakin_predictor_p =
       dynamic_cast<PaddleInferenceAnakinPredictor<Target> *>(cls.get());
   if (!anakin_predictor_p) {
-    LOG(ERROR) << "fail to call Init";
+    VLOG(3) << "fail to call Init";
     return nullptr;
   }
   anakin_predictor_p->get_executer().init(graph_);
@@ -162,6 +215,44 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
   VLOG(3) << "Anakin Predictor create on unknown platform.";
   return nullptr;
 }
-};
+}
+
+#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
+template <typename Target>
+using executor_t =
+    anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>;
+
+template <typename Target>
+void DisplayOpTimer(executor_t<Target> *net_executor, int epoch) {
+  std::vector<float> op_time = net_executor->get_op_time();
+  auto exec_funcs = net_executor->get_exec_funcs();
+  auto op_param = net_executor->get_op_param();
+  for (int i = 0; i < op_time.size(); i++) {
+    LOG(INFO) << "name: " << exec_funcs[i].name
+              << " op_type: " << exec_funcs[i].op_name
+              << " op_param: " << op_param[i] << " time " << op_time[i] / epoch;
+  }
+  std::map<std::string, float> op_map;
+  for (int i = 0; i < op_time.size(); i++) {
+    auto it = op_map.find(op_param[i]);
+    if (it != op_map.end())
+      op_map[op_param[i]] += op_time[i];
+    else
+      op_map.insert(std::pair<std::string, float>(op_param[i], op_time[i]));
+  }
+  for (auto it = op_map.begin(); it != op_map.end(); ++it) {
+    LOG(INFO) << it->first << " " << (it->second) / epoch << " ms";
+  }
+}
+#endif
+
+template <typename Target>
+PaddleInferenceAnakinPredictor<Target>::~PaddleInferenceAnakinPredictor() {
+#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
+  DisplayOpTimer<Target>(executor_p_, max_batch_size_);
+#endif
+  delete executor_p_;
+  executor_p_ = nullptr;
+}

 }  // namespace paddle
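For reference, a minimal sketch of driving this predictor through the public inference API. Only what is visible in this diff is assumed about the types (AnakinConfig::model_file and max_batch_size, the PaddleTensor fields, and Run()); the header path, the engine-kind template argument, the model path, and the tensor names are illustrative placeholders:

#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnakinConfig config;
  config.model_file = "model.anakin.bin";  // hypothetical model path
  config.max_batch_size = 1;

  // The engine-kind template argument is assumed from the Anakin demo usage;
  // check paddle_inference_api.h for the exact specialization.
  auto predictor =
      paddle::CreatePaddlePredictor<paddle::AnakinConfig,
                                    paddle::PaddleEngineKind::kAnakin>(config);

  std::vector<float> buf(1 * 3 * 224 * 224, 0.f);
  paddle::PaddleTensor in;
  in.name = "input_0";  // hypothetical input name
  in.shape = {1, 3, 224, 224};
  in.dtype = paddle::PaddleDType::FLOAT32;
  in.data = paddle::PaddleBuf(buf.data(), buf.size() * sizeof(float));

  paddle::PaddleTensor out;
  out.name = "output_0";  // outputs must be named, per Run() above
  std::vector<paddle::PaddleTensor> outputs{out};

  predictor->Run({in}, &outputs, 1);
  return 0;
}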