@@ -27,22 +27,76 @@ extern "C" {
27
27
int PDGradientMachineCreateForPredict (PD_GradiemtMachine* machine,
28
28
void * modelConfigProtobuf,
29
29
int size) {
30
- if (modelConfigProtobuf == nullptr ) return PD_NULLPTR ;
30
+ if (modelConfigProtobuf == nullptr ) return kPD_NULLPTR ;
31
31
paddle::ModelConfig config;
32
32
if (!config.ParseFromArray (modelConfigProtobuf, size) ||
33
33
!config.IsInitialized ()) {
34
- return PD_PROTOBUF_ERROR ;
34
+ return kPD_PROTOBUF_ERROR ;
35
35
}
36
36
37
37
auto ptr = new paddle::capi::CGradientMachine ();
38
38
ptr->machine .reset (paddle::GradientMachine::create (
39
39
config, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
40
40
*machine = ptr;
41
- return PD_NO_ERROR ;
41
+ return kPD_NO_ERROR ;
42
42
}
43
43
44
44
int PDGradientMachineDestroy (PD_GradiemtMachine machine) {
45
45
delete cast (machine);
46
- return PD_NO_ERROR;
46
+ return kPD_NO_ERROR ;
47
+ }
48
+
49
+ int PDGradientMachineLoadParameterFromDisk (PD_GradiemtMachine machine,
50
+ const char * path) {
51
+ auto m = cast (machine);
52
+ if (m == nullptr || path == nullptr || m->machine == nullptr )
53
+ return kPD_NULLPTR ;
54
+ m->machine ->loadParameters (path);
55
+ return kPD_NO_ERROR ;
56
+ }
57
+
58
+ int PDGradientMachineForward (PD_GradiemtMachine machine,
59
+ PD_Arguments inArgs,
60
+ PD_Arguments outArgs,
61
+ bool isTrain) {
62
+ auto m = cast (machine);
63
+ auto in = paddle::capi::cast<paddle::capi::CArguments>(inArgs);
64
+ auto out = paddle::capi::cast<paddle::capi::CArguments>(outArgs);
65
+ if (m == nullptr || in == nullptr || out == nullptr || m->machine == nullptr )
66
+ return kPD_NULLPTR ;
67
+ m->machine ->forward (
68
+ in->args , &out->args , isTrain ? paddle::PASS_TRAIN : paddle::PASS_TEST);
69
+ return kPD_NO_ERROR ;
70
+ }
71
+
72
+ int PDGradientMachineCreateSharedParam (PD_GradiemtMachine origin,
73
+ void * modelConfigProtobuf,
74
+ int size,
75
+ PD_GradiemtMachine* slave) {
76
+ auto o = cast (origin);
77
+ if (origin == nullptr || slave == nullptr || o->machine == nullptr ) {
78
+ return kPD_NULLPTR ;
79
+ }
80
+ paddle::ModelConfig config;
81
+ if (!config.ParseFromArray (modelConfigProtobuf, size) ||
82
+ !config.IsInitialized ()) {
83
+ return kPD_PROTOBUF_ERROR ;
84
+ }
85
+
86
+ std::unique_ptr<paddle::capi::CGradientMachine> ptr (
87
+ new paddle::capi::CGradientMachine ());
88
+ auto nn = paddle::NeuralNetwork::create (config);
89
+ nn->init (config,
90
+ [&o](int paramId, paddle::Parameter* param) {
91
+ auto p = o->machine ->getParameters ()[paramId];
92
+ param->enableSharedType (paddle::PARAMETER_VALUE,
93
+ p->getBuf (paddle::PARAMETER_VALUE));
94
+
95
+ },
96
+ {paddle::PARAMETER_VALUE},
97
+ false );
98
+ ptr->machine .reset (nn);
99
+ *slave = ptr.release ();
100
+ return kPD_NO_ERROR ;
47
101
}
48
102
}
0 commit comments