Skip to content

Commit 43dd30b

Browse files
authored
fix go api (#22670)
1 parent f2fe531 commit 43dd30b

File tree

6 files changed

+36
-18
lines changed

6 files changed

+36
-18
lines changed

go/demo/mobilenet_c.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ int main(int argc, char *argv[]) {
2929
printf("Output num: %d\n", output_num);
3030

3131
PD_ZeroCopyTensor input;
32+
PD_InitZeroCopyTensor(&input);
3233
input.name = const_cast<char *>(PD_GetInputName(predictor, 0)); // NOLINT
3334
input.data.capacity = sizeof(float) * 1 * 3 * 300 * 300;
3435
input.data.length = input.data.capacity;
@@ -48,6 +49,7 @@ int main(int argc, char *argv[]) {
4849
PD_InitZeroCopyTensor(&output);
4950
output.name = const_cast<char *>(PD_GetOutputName(predictor, 0)); // NOLINT
5051
PD_GetZeroCopyOutput(predictor, &output);
52+
5153
PD_DestroyZeroCopyTensor(&output);
5254

5355
PD_DeleteAnalysisConfig(config);

go/demo/mobilenet_cxx.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ int main(int argc, char *argv[]) {
3535
input->copy_from_cpu(data.data());
3636
predictor->ZeroCopyRun();
3737
auto output_name = predictor->GetOutputNames()[0];
38-
output = predictor->GetOutputTensor(output_name);
38+
auto output = predictor->GetOutputTensor(output_name);
3939
return 0;
4040
}
4141

4242
void SetConfig(paddle::AnalysisConfig *config) {
4343
config->SetModel("data/model/__model__", "data/model/__params__");
44-
config->SwitchUseFeedFetchOps(true);
44+
config->SwitchUseFeedFetchOps(false);
4545
config->SwitchSpecifyInputNames(true);
4646
config->SwitchIrOptim(false);
4747
}

go/paddle/common.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ package paddle
2121
import "C"
2222
import "fmt"
2323

24+
type Precision C.Precision
25+
26+
const (
27+
kFloat32 Precision = C.kFloat32
28+
kInt8 Precision = C.kInt8
29+
kHalf Precision = C.kHalf
30+
)
31+
2432
func ConvertCBooleanToGo(b C.bool) bool {
2533
var c_false C.bool
2634
if b != c_false {

go/paddle/config.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,13 @@ func (config *AnalysisConfig) SetModel(model, params string) {
4343
//C.printString((*C.char)(unsafe.Pointer(&s[0])))
4444
c_model := C.CString(model)
4545
defer C.free(unsafe.Pointer(c_model))
46-
c_params := C.CString(params)
47-
defer C.free(unsafe.Pointer(c_params))
46+
var c_params *C.char
47+
if params == "" {
48+
c_params = nil
49+
} else {
50+
c_params = C.CString(params)
51+
defer C.free(unsafe.Pointer(c_params))
52+
}
4853

4954
C.PD_SetModel(config.c, c_model, c_params)
5055
}
@@ -61,8 +66,8 @@ func (config *AnalysisConfig) ParamsFile() string {
6166
return C.GoString(C.PD_ParamsFile(config.c))
6267
}
6368

64-
func (config *AnalysisConfig) EnableUseGpu(memory_pool_init_size_mb uint64, device_id int) {
65-
C.PD_EnableUseGpu(config.c, C.ulong(memory_pool_init_size_mb), C.int(device_id))
69+
func (config *AnalysisConfig) EnableUseGpu(memory_pool_init_size_mb int, device_id int) {
70+
C.PD_EnableUseGpu(config.c, C.int(memory_pool_init_size_mb), C.int(device_id))
6671
}
6772

6873
func (config *AnalysisConfig) DisableGpu() {
@@ -113,7 +118,9 @@ func (config *AnalysisConfig) SpecifyInputName() bool {
113118
return ConvertCBooleanToGo(C.PD_SpecifyInputName(config.c))
114119
}
115120

116-
//func (config *AnalysisConfig) EnableTensorRtEngine(workspace_size int)
121+
func (config *AnalysisConfig) EnableTensorRtEngine(workspace_size int, max_batch_size int, min_subgraph_size int, precision Precision, use_static bool, use_calib_mode bool) {
122+
C.PD_EnableTensorRtEngine(config.c, C.int(workspace_size), C.int(max_batch_size), C.int(min_subgraph_size), C.Precision(precision), C.bool(use_static), C.bool(use_calib_mode))
123+
}
117124

118125
func (config *AnalysisConfig) TensorrtEngineEnabled() bool {
119126
return ConvertCBooleanToGo(C.PD_TensorrtEngineEnabled(config.c))
@@ -175,15 +182,15 @@ func (config *AnalysisConfig) DisableGlogInfo() {
175182
}
176183

177184
func (config *AnalysisConfig) DeletePass(pass string) {
178-
c_pass := C.CString(pass)
179-
defer C.free(unsafe.Pointer(c_pass))
180-
C.PD_DeletePass(config.c, c_pass)
185+
c_pass := C.CString(pass)
186+
defer C.free(unsafe.Pointer(c_pass))
187+
C.PD_DeletePass(config.c, c_pass)
181188
}
182189

183190
func (config *AnalysisConfig) SetInValid() {
184-
C.PD_SetInValid(config.c)
191+
C.PD_SetInValid(config.c)
185192
}
186193

187194
func (config *AnalysisConfig) IsValid() bool {
188-
return ConvertCBooleanToGo(C.PD_IsValid(config.c))
195+
return ConvertCBooleanToGo(C.PD_IsValid(config.c))
189196
}

paddle/fluid/inference/capi/paddle_c_api.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ PADDLE_CAPI_EXPORT extern const char* PD_ProgFile(
161161
PADDLE_CAPI_EXPORT extern const char* PD_ParamsFile(
162162
const PD_AnalysisConfig* config);
163163

164-
PADDLE_CAPI_EXPORT extern void PD_EnableUseGpu(
165-
PD_AnalysisConfig* config, uint64_t memory_pool_init_size_mb,
166-
int device_id);
164+
PADDLE_CAPI_EXPORT extern void PD_EnableUseGpu(PD_AnalysisConfig* config,
165+
int memory_pool_init_size_mb,
166+
int device_id);
167167

168168
PADDLE_CAPI_EXPORT extern void PD_DisableGpu(PD_AnalysisConfig* config);
169169

paddle/fluid/inference/capi/pd_config.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,11 @@ const char* PD_ParamsFile(const PD_AnalysisConfig* config) {
7979
return config->config.params_file().c_str();
8080
}
8181

82-
void PD_EnableUseGpu(PD_AnalysisConfig* config,
83-
uint64_t memory_pool_init_size_mb, int device_id) {
82+
void PD_EnableUseGpu(PD_AnalysisConfig* config, int memory_pool_init_size_mb,
83+
int device_id) {
8484
PADDLE_ENFORCE_NOT_NULL(config);
85-
config->config.EnableUseGpu(memory_pool_init_size_mb, device_id);
85+
config->config.EnableUseGpu(static_cast<uint64_t>(memory_pool_init_size_mb),
86+
device_id);
8687
}
8788

8889
void PD_DisableGpu(PD_AnalysisConfig* config) {

0 commit comments

Comments
 (0)