File tree Expand file tree Collapse file tree 3 files changed +12
-4
lines changed Expand file tree Collapse file tree 3 files changed +12
-4
lines changed Original file line number Diff line number Diff line change @@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
64
64
modelConfigProtobuf.resize (modelConfigSize);
65
65
is.read (&modelConfigProtobuf[0 ], modelConfigSize);
66
66
paddle::TrainerConfig config;
67
+ paddle::ModelConfig modelConfig;
67
68
if (!config.ParseFromString (modelConfigProtobuf) || !config.IsInitialized ()) {
68
- return kPD_PROTOBUF_ERROR ;
69
+ if (!modelConfig.ParseFromString (modelConfigProtobuf) ||
70
+ !modelConfig.IsInitialized ()) {
71
+ return kPD_PROTOBUF_ERROR ;
72
+ }
73
+ } else {
74
+ modelConfig = config.model_config ();
69
75
}
70
76
auto ptr = new paddle::capi::CGradientMachine ();
71
77
ptr->machine .reset (paddle::GradientMachine::create (
72
- config. model_config () , CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
78
+ modelConfig , CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
73
79
std::vector<paddle::ParameterPtr>& parameters = ptr->machine ->getParameters ();
74
80
for (auto & para : parameters) {
75
81
para->load (is);
Original file line number Diff line number Diff line change @@ -20,6 +20,7 @@ limitations under the License. */
20
20
#include " paddle/utils/PythonUtil.h"
21
21
22
22
DEFINE_string (model_dir, " " , " Directory for separated model files" );
23
+ DEFINE_string (config_file, " " , " Config file for the model" );
23
24
DEFINE_string (model_file, " " , " File for merged model file" );
24
25
25
26
using namespace paddle ; // NOLINT
@@ -28,7 +29,8 @@ using namespace std; // NOLINT
28
29
int main (int argc, char ** argv) {
29
30
initMain (argc, argv);
30
31
initPython (argc, argv);
31
- string confFile = TrainerConfigHelper::getConfigNameFromPath (FLAGS_model_dir);
32
+
33
+ string confFile = FLAGS_config_file;
32
34
#ifndef PADDLE_WITH_CUDA
33
35
FLAGS_use_gpu = false ;
34
36
#endif
Original file line number Diff line number Diff line change @@ -19,7 +19,7 @@ import "ModelConfig.proto";
19
19
package paddle ;
20
20
21
21
message OptimizationConfig {
22
- required int32 batch_size = 3 ;
22
+ optional int32 batch_size = 3 [ default = 1 ] ;
23
23
required string algorithm = 4 [ default = "async_sgd" ];
24
24
optional int32 num_batches_per_send_parameter = 5 [ default = 1 ];
25
25
optional int32 num_batches_per_get_parameter = 6 [ default = 1 ];
You can’t perform that action at this time.
0 commit comments