Skip to content

Commit a06883c

Browse files
authored
Golang inference API (#22503) (#22601)
* support golang inference
1 parent edcf04c commit a06883c

20 files changed

+1102
-27
lines changed

cmake/inference_lib.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ else(WIN32)
190190
endif(WIN32)
191191

192192
copy(inference_lib_dist
193-
SRCS ${src_dir}/inference/capi/c_api.h ${paddle_fluid_c_lib}
193+
SRCS ${src_dir}/inference/capi/paddle_c_api.h ${paddle_fluid_c_lib}
194194
DSTS ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/include ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/lib)
195195

196196
# fluid library for both train and inference

go/README_cn.md

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Paddle 预测golang API
2+
3+
## 安装
4+
首先cmake编译时打开`-DON_INFER=ON`,在编译目录下得到``fluid_inference_c_install_dir``,将该目录移动到当前目录中并重命名为`paddle_c`
5+
6+
## 在Go中使用Paddle预测
7+
首先创建预测配置
8+
``` go
9+
config := paddle.NewAnalysisConfig()
10+
config.SetModel(model_file, params_file)
11+
config.SwitchUseFeedFetchOps(false)
12+
config.SwitchSpecifyInputNames(true)
13+
```
14+
15+
创建predictor
16+
``` go
17+
predictor := paddle.NewPredictor(config)
18+
```
19+
20+
获取输入Tensor和输出Tensor
21+
``` go
22+
inputs = predictor.GetInputTensors()
23+
```
24+
25+
设置输入数据(假设只有一个输入)
26+
``` go
27+
input := inputs[0]
28+
input.SetValue(data)
29+
input.Reshape([]int32{1, 3, 300, 300})
30+
```
31+
32+
运行预测
33+
``` go
34+
predictor.ZeroCopyRun()
35+
```
36+
37+
获取输入Tensor的真实值
38+
``` go
39+
output := outputs[0]
40+
predictor.GetZeroCopyOutput(output)
41+
value := reflect.ValueOf(output.Value())
42+
shape, dtype := paddle.ShapeAndTypeOf(value)
43+
output_data := value.Interface().([][]float32)
44+
```
45+
46+
## 示例
47+
源码见[mobilenet](./demo/mobilenet.go)
48+
49+
下载[数据](https://paddle-inference-dist.cdn.bcebos.com/mobilenet-test-model-data.tar.gz)并解压到当前目录
50+
51+
运行
52+
``` go
53+
go run ./demo/mobilenet.go
54+
```

go/demo/mobilenet.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
package main
15+
16+
import "../paddle"
17+
import "strings"
18+
import "io/ioutil"
19+
import "strconv"
20+
import "reflect"
21+
22+
func main() {
23+
config := paddle.NewAnalysisConfig()
24+
config.SetModel("data/model/__model__", "data/model/__params__")
25+
config.DisableGlogInfo()
26+
config.SwitchUseFeedFetchOps(false)
27+
config.SwitchSpecifyInputNames(true)
28+
29+
predictor := paddle.NewPredictor(config)
30+
31+
println("============== paddle inference ==============")
32+
println("input num: ", predictor.GetInputNum())
33+
println("input name: ", predictor.GetInputNames()[0])
34+
println("output num: ", predictor.GetOutputNum())
35+
println("output name: ", predictor.GetInputNames()[0])
36+
println("============== run inference =================")
37+
38+
input := predictor.GetInputTensors()[0]
39+
output := predictor.GetOutputTensors()[0]
40+
41+
filename := "data/data.txt"
42+
data := ReadData(filename)
43+
input.SetValue(data[:1 * 3 * 300 * 300])
44+
input.Reshape([]int32{1, 3, 300, 300})
45+
46+
predictor.SetZeroCopyInput(input)
47+
predictor.ZeroCopyRun()
48+
predictor.GetZeroCopyOutput(output)
49+
50+
println("============= parse output ===================")
51+
output_val := output.Value()
52+
value := reflect.ValueOf(output_val)
53+
shape, dtype := paddle.ShapeAndTypeOf(value)
54+
switch dtype {
55+
case paddle.PaddleDType(paddle.FLOAT32):
56+
v := value.Interface().([][]float32)
57+
println("v: ", v[0][0], v[0][1], "...")
58+
case paddle.PaddleDType(paddle.UINT8):
59+
v := value.Interface().([][]uint8)
60+
println("v: ", v[0][0], v[0][1], "...")
61+
case paddle.PaddleDType(paddle.INT32):
62+
v := value.Interface().([][]int32)
63+
println("v: ", v[0][0], v[0][1], "...")
64+
case paddle.PaddleDType(paddle.INT64):
65+
v := value.Interface().([][]int64)
66+
println("v: ", v[0][0], v[0][1], "...")
67+
}
68+
println(shape[0], shape[1])
69+
println(output.Shape()[0])
70+
}
71+
72+
func ReadData(filename string) []float32 {
73+
file_bytes, _ := ioutil.ReadFile(filename)
74+
data_slice := strings.Split(string(file_bytes), " ")
75+
var result []float32
76+
for _, n := range data_slice {
77+
r, _ := strconv.ParseFloat(n, 32)
78+
result = append(result, float32(r))
79+
}
80+
return result
81+
}

go/demo/mobilenet_c.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include <paddle_c_api.h>
15+
#include <stdio.h>
16+
#include <stdlib.h>
17+
18+
void SetConfig(PD_AnalysisConfig *);
19+
void ReadData(float *data, int size);
20+
21+
int main(int argc, char *argv[]) {
22+
PD_AnalysisConfig *config = PD_NewAnalysisConfig();
23+
SetConfig(config);
24+
PD_Predictor *predictor = PD_NewPredictor(config);
25+
26+
int input_num = PD_GetInputNum(predictor);
27+
printf("Input num: %d\n", input_num);
28+
int output_num = PD_GetOutputNum(predictor);
29+
printf("Output num: %d\n", output_num);
30+
31+
PD_ZeroCopyTensor input;
32+
input.name = const_cast<char *>(PD_GetInputName(predictor, 0)); // NOLINT
33+
input.data.capacity = sizeof(float) * 1 * 3 * 300 * 300;
34+
input.data.length = input.data.capacity;
35+
input.data.data = malloc(input.data.capacity);
36+
int shape[] = {1, 3, 300, 300};
37+
input.shape.data = static_cast<int *>(shape);
38+
input.shape.capacity = sizeof(shape);
39+
input.shape.length = sizeof(shape);
40+
input.dtype = PD_FLOAT32;
41+
ReadData((float *)input.data.data, 1 * 3 * 300 * 300); // NOLINT
42+
float *data = (float *)input.data.data; // NOLINT
43+
PD_SetZeroCopyInput(predictor, &input);
44+
int *shape_ptr = (int *)input.shape.data; // NOLINT
45+
46+
PD_ZeroCopyRun(predictor);
47+
PD_ZeroCopyTensor output;
48+
PD_InitZeroCopyTensor(&output);
49+
output.name = const_cast<char *>(PD_GetOutputName(predictor, 0)); // NOLINT
50+
PD_GetZeroCopyOutput(predictor, &output);
51+
PD_DestroyZeroCopyTensor(&output);
52+
53+
PD_DeleteAnalysisConfig(config);
54+
PD_DeletePredictor(predictor);
55+
return 0;
56+
}
57+
58+
void SetConfig(PD_AnalysisConfig *config) {
59+
PD_SetModel(config, "data/model/__model__", "data/model/__params__");
60+
PD_SwitchUseFeedFetchOps(config, false);
61+
PD_SwitchSpecifyInputNames(config, true);
62+
PD_DisableGlogInfo(config);
63+
// PD_SwitchIrOptim(config, false);
64+
}
65+
66+
void ReadData(float *data, int n) {
67+
FILE *fp = fopen("data/data.txt", "r");
68+
for (int i = 0; i < n; i++) {
69+
fscanf(fp, "%f", &data[i]);
70+
}
71+
fclose(fp);
72+
}

go/demo/mobilenet_cxx.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include <paddle_inference_api.h>
15+
#include <fstream>
16+
#include <iostream>
17+
18+
void SetConfig(paddle::AnalysisConfig *);
19+
20+
int main(int argc, char *argv[]) {
21+
paddle::AnalysisConfig config;
22+
SetConfig(&config);
23+
auto predictor = paddle::CreatePaddlePredictor(config);
24+
auto input_name = predictor->GetInputNames()[0];
25+
auto input = predictor->GetInputTensor(input_name);
26+
std::cout << predictor->GetOutputNames()[0] << std::endl;
27+
std::vector<int> shape{1, 3, 300, 300};
28+
input->Reshape(std::move(shape));
29+
std::vector<float> data(1 * 300 * 300 * 3);
30+
std::ifstream fin("data/data.txt");
31+
for (int i = 0; i < data.size(); i++) {
32+
fin >> data[i];
33+
}
34+
35+
input->copy_from_cpu(data.data());
36+
predictor->ZeroCopyRun();
37+
auto output_name = predictor->GetOutputNames()[0];
38+
output = predictor->GetOutputTensor(output_name);
39+
return 0;
40+
}
41+
42+
void SetConfig(paddle::AnalysisConfig *config) {
43+
config->SetModel("data/model/__model__", "data/model/__params__");
44+
config->SwitchUseFeedFetchOps(true);
45+
config->SwitchSpecifyInputNames(true);
46+
config->SwitchIrOptim(false);
47+
}

go/paddle/common.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package paddle
16+
17+
// #cgo CFLAGS: -Ipaddle_c/paddle/include
18+
// #cgo LDFLAGS: -Lpaddle_c/paddle/lib -lpaddle_fluid_c
19+
// #include <stdbool.h>
20+
// #include <paddle_c_api.h>
21+
import "C"
22+
import "fmt"
23+
24+
func ConvertCBooleanToGo(b C.bool) bool {
25+
var c_false C.bool
26+
if b != c_false {
27+
return true
28+
}
29+
return false
30+
}
31+
32+
func numel(shape []int32) int32 {
33+
n := int32(1)
34+
for _, d := range shape {
35+
n *= d
36+
}
37+
return n
38+
}
39+
40+
func bug(format string, args ...interface{}) error {
41+
return fmt.Errorf("Bug %v", fmt.Sprintf(format, args...))
42+
}

0 commit comments

Comments
 (0)