Skip to content

Commit 4273b35

Browse files
authored
Merge pull request #4473 from NHZlX/fix_merge_model
refine paddle_merge_model
2 parents aa3de35 + 01cd69d commit 4273b35

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

paddle/capi/gradient_machine.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
6464
modelConfigProtobuf.resize(modelConfigSize);
6565
is.read(&modelConfigProtobuf[0], modelConfigSize);
6666
paddle::TrainerConfig config;
67+
paddle::ModelConfig modelConfig;
6768
if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) {
68-
return kPD_PROTOBUF_ERROR;
69+
if (!modelConfig.ParseFromString(modelConfigProtobuf) ||
70+
!modelConfig.IsInitialized()) {
71+
return kPD_PROTOBUF_ERROR;
72+
}
73+
} else {
74+
modelConfig = config.model_config();
6975
}
7076
auto ptr = new paddle::capi::CGradientMachine();
7177
ptr->machine.reset(paddle::GradientMachine::create(
72-
config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
78+
modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
7379
std::vector<paddle::ParameterPtr>& parameters = ptr->machine->getParameters();
7480
for (auto& para : parameters) {
7581
para->load(is);

paddle/trainer/MergeModel.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License. */
2020
#include "paddle/utils/PythonUtil.h"
2121

2222
DEFINE_string(model_dir, "", "Directory for separated model files");
23+
DEFINE_string(config_file, "", "Config file for the model");
2324
DEFINE_string(model_file, "", "File for merged model file");
2425

2526
using namespace paddle; // NOLINT
@@ -28,7 +29,8 @@ using namespace std; // NOLINT
2829
int main(int argc, char** argv) {
2930
initMain(argc, argv);
3031
initPython(argc, argv);
31-
string confFile = TrainerConfigHelper::getConfigNameFromPath(FLAGS_model_dir);
32+
33+
string confFile = FLAGS_config_file;
3234
#ifndef PADDLE_WITH_CUDA
3335
FLAGS_use_gpu = false;
3436
#endif

proto/TrainerConfig.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import "ModelConfig.proto";
1919
package paddle;
2020

2121
message OptimizationConfig {
22-
required int32 batch_size = 3;
22+
optional int32 batch_size = 3 [ default = 1 ];
2323
required string algorithm = 4 [ default = "async_sgd" ];
2424
optional int32 num_batches_per_send_parameter = 5 [ default = 1 ];
2525
optional int32 num_batches_per_get_parameter = 6 [ default = 1 ];

0 commit comments

Comments
 (0)