 namespace paddle {

-PaddleInferenceAnakinPredictor::PaddleInferenceAnakinPredictor(
+template <typename Target>
+PaddleInferenceAnakinPredictor<Target>::PaddleInferenceAnakinPredictor(
     const AnakinConfig &config) {
   CHECK(Init(config));
 }

-bool PaddleInferenceAnakinPredictor::Init(const AnakinConfig &config) {
+template <typename Target>
+bool PaddleInferenceAnakinPredictor<Target>::Init(const AnakinConfig &config) {
   if (!(graph_.load(config.model_file))) {
+    LOG(FATAL) << "fail to load graph from " << config.model_file;
     return false;
   }
-  graph_.ResetBatchSize("input_0", config.max_batch_size);
+  auto inputs = graph_.get_ins();
+  for (auto &input_str : inputs) {
+    graph_.ResetBatchSize(input_str, config.max_batch_size);
+  }
   // optimization for graph
   if (!(graph_.Optimize())) {
     return false;
   }
   // construct executer
-  executor_.init(graph_);
+  if (executor_p_ == nullptr) {
+    executor_p_ = new anakin::Net<Target, anakin::saber::AK_FLOAT,
+                                  anakin::Precision::FP32>(graph_, true);
+  }
   return true;
 }

-bool PaddleInferenceAnakinPredictor::Run(
+template <typename Target>
+bool PaddleInferenceAnakinPredictor<Target>::Run(
     const std::vector<PaddleTensor> &inputs,
     std::vector<PaddleTensor> *output_data, int batch_size) {
   for (const auto &input : inputs) {
@@ -46,7 +56,29 @@ bool PaddleInferenceAnakinPredictor::Run(
                  << "'s type is not float";
       return false;
     }
-    auto d_tensor_in_p = executor_.get_in(input.name);
+    auto d_tensor_in_p = executor_p_->get_in(input.name);
+    auto net_shape = d_tensor_in_p->valid_shape();
+    if (net_shape.size() != input.shape.size()) {
+      LOG(ERROR) << "input " << input.name
+                 << "'s shape size should be equal to that of net";
+      return false;
+    }
+    int sum = 1;
+    for_each(input.shape.begin(), input.shape.end(), [&](int n) { sum *= n; });
+    if (sum > net_shape.count()) {
+      graph_.Reshape(input.name, input.shape);
+      delete executor_p_;
+      executor_p_ = new anakin::Net<Target, anakin::saber::AK_FLOAT,
+                                    anakin::Precision::FP32>(graph_, true);
+      d_tensor_in_p = executor_p_->get_in(input.name);
+    }
+
+    anakin::saber::Shape tmp_shape;
+    for (auto s : input.shape) {
+      tmp_shape.push_back(s);
+    }
+    d_tensor_in_p->reshape(tmp_shape);
+
     float *d_data_p = d_tensor_in_p->mutable_data();
     if (cudaMemcpy(d_data_p, static_cast<float *>(input.data.data()),
                    d_tensor_in_p->valid_size() * sizeof(float),
@@ -56,16 +88,17 @@ bool PaddleInferenceAnakinPredictor::Run(
     }
     cudaStreamSynchronize(NULL);
   }
-
-  executor_.prediction();
+  cudaDeviceSynchronize();
+  executor_p_->prediction();
+  cudaDeviceSynchronize();

   if (output_data->empty()) {
     LOG(ERROR) << "At least one output should be set with tensors' names.";
     return false;
   }
   for (auto &output : *output_data) {
-    auto *tensor = executor_.get_out(output.name);
-    output.shape = tensor->shape();
+    auto *tensor = executor_p_->get_out(output.name);
+    output.shape = tensor->valid_shape();
     if (output.data.length() < tensor->valid_size() * sizeof(float)) {
       output.data.Resize(tensor->valid_size() * sizeof(float));
     }
@@ -81,19 +114,23 @@ bool PaddleInferenceAnakinPredictor::Run(
   return true;
 }

-anakin::Net<anakin::NV, anakin::saber::AK_FLOAT, anakin::Precision::FP32>
-    &PaddleInferenceAnakinPredictor::get_executer() {
-  return executor_;
+template <typename Target>
+anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>
+    &PaddleInferenceAnakinPredictor<Target>::get_executer() {
+  return *executor_p_;
 }

 // the cloned new Predictor of anakin share the same net weights from original
 // Predictor
-std::unique_ptr<PaddlePredictor> PaddleInferenceAnakinPredictor::Clone() {
+template <typename Target>
+std::unique_ptr<PaddlePredictor>
+PaddleInferenceAnakinPredictor<Target>::Clone() {
   VLOG(3) << "Anakin Predictor::clone";
-  std::unique_ptr<PaddlePredictor> cls(new PaddleInferenceAnakinPredictor());
+  std::unique_ptr<PaddlePredictor> cls(
+      new PaddleInferenceAnakinPredictor<Target>());
   // construct executer from other graph
   auto anakin_predictor_p =
-      dynamic_cast<PaddleInferenceAnakinPredictor *>(cls.get());
+      dynamic_cast<PaddleInferenceAnakinPredictor<Target> *>(cls.get());
   if (!anakin_predictor_p) {
     LOG(ERROR) << "fail to call Init";
     return nullptr;
@@ -103,14 +140,28 @@ std::unique_ptr<PaddlePredictor> PaddleInferenceAnakinPredictor::Clone() {
   return std::move(cls);
 }

+template class PaddleInferenceAnakinPredictor<anakin::NV>;
+template class PaddleInferenceAnakinPredictor<anakin::X86>;
+
 // A factory to help create difference predictor.
 template <>
 std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
     AnakinConfig, PaddleEngineKind::kAnakin>(const AnakinConfig &config) {
   VLOG(3) << "Anakin Predictor create.";
-  std::unique_ptr<PaddlePredictor> x(
-      new PaddleInferenceAnakinPredictor(config));
-  return x;
-}
+  if (config.target_type == AnakinConfig::NVGPU) {
+    VLOG(3) << "Anakin Predictor create on [ NVIDIA GPU ].";
+    std::unique_ptr<PaddlePredictor> x(
+        new PaddleInferenceAnakinPredictor<anakin::NV>(config));
+    return x;
+  } else if (config.target_type == AnakinConfig::X86) {
+    VLOG(3) << "Anakin Predictor create on [ Intel X86 ].";
+    std::unique_ptr<PaddlePredictor> x(
+        new PaddleInferenceAnakinPredictor<anakin::X86>(config));
+    return x;
+  } else {
+    VLOG(3) << "Anakin Predictor create on unknown platform.";
+    return nullptr;
+  }
+};

 }  // namespace paddle
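
For reference, a minimal sketch of how the templated factory above might be driven from application code. It assumes the AnakinConfig fields exercised in this change (model_file, max_batch_size, target_type) together with the PaddleTensor/PaddleBuf/PaddleDType types declared in paddle_inference_api.h; the include path, model path, tensor names, and shapes below are hypothetical placeholders, not values taken from this patch.

#include <vector>
// Header path is an assumption; it may differ by tree layout.
#include "paddle/contrib/inference/paddle_inference_api.h"

int main() {
  // Config fields referenced by Init() and the factory specialization above.
  paddle::AnakinConfig config;
  config.model_file = "./model.anakin.bin";          // hypothetical model path
  config.max_batch_size = 1;
  config.target_type = paddle::AnakinConfig::NVGPU;  // or paddle::AnakinConfig::X86

  auto predictor =
      paddle::CreatePaddlePredictor<paddle::AnakinConfig,
                                    paddle::PaddleEngineKind::kAnakin>(config);
  if (!predictor) return 1;  // nullptr is returned for an unknown target_type

  // One float input tensor; the name and shape are model-dependent and
  // illustrative only. PaddleDType and the PaddleBuf constructor are assumed
  // from the public API, not shown in this diff.
  std::vector<float> input_data(1 * 3 * 224 * 224, 0.f);
  paddle::PaddleTensor input;
  input.name = "input_0";
  input.shape = {1, 3, 224, 224};
  input.dtype = paddle::PaddleDType::FLOAT32;
  input.data = paddle::PaddleBuf(input_data.data(),
                                 input_data.size() * sizeof(float));

  // Run() looks up each requested output by name and copies the result back.
  paddle::PaddleTensor output;
  output.name = "prob_out";  // hypothetical output name
  std::vector<paddle::PaddleTensor> outputs{output};

  if (!predictor->Run({input}, &outputs, 1)) return 1;
  // outputs[0].shape / outputs[0].data now hold the predicted tensor.
  return 0;
}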