diff --git a/demo/quick_start/cluster/pserver.py b/demo/quick_start/cluster/pserver.py new file mode 100644 index 00000000000000..b36f3749877e01 --- /dev/null +++ b/demo/quick_start/cluster/pserver.py @@ -0,0 +1,32 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from py_paddle import swig_paddle as api +import paddle.proto.ParameterServerConfig_pb2 as ParameterServerConfig + + +def main(): + api.initPaddle() + pServerConfig = ParameterServerConfig.ParameterServerConfig() + pServerConfig.ports_num = 1 + pServerConfig.nics = "lo0" + pServerConfig.num_gradient_servers = 1 + pServerConfig.port = 7164 + pserver = api.ParameterServer.createFromConfigProto(pServerConfig) + pserver.start() + pserver.wait() + + +if __name__ == '__main__': + main() diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index 6e8fcd114df580..bf46aea31f1db3 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -28,7 +28,8 @@ set(API_SOURCES SequenceGenerator.cpp Trainer.cpp Util.cpp - Vector.cpp) + Vector.cpp + ParameterServer.cpp) set(API_HEADER PaddleAPI.h Internal.h) diff --git a/paddle/api/Paddle.swig b/paddle/api/Paddle.swig index 068ba286c07d88..ca5d4fcefaf12b 100644 --- a/paddle/api/Paddle.swig +++ b/paddle/api/Paddle.swig @@ -179,6 +179,8 @@ namespace std { %newobject ParameterOptimizer::needSpecialTraversal; %newobject ParameterUpdater::createLocalUpdater; %newobject 
ParameterUpdater::createRemoteUpdater; +%newobject ParameterServer::createByConfigProtoPtr; +%newobject ParameterServer::createByConfigProtoStr; %feature("director") UpdateCallback; %feature("autodoc", 1); // To generate method stub, for code hint in ide @@ -197,5 +199,6 @@ namespace std { %ignore ParameterConfigPrivate; %ignore OptimizationConfigPrivate; %ignore ParameterTraverseCallbackPrivate; +%ignore ParameterServerPrivate; %include "utils/GlobalConstants.h" %include "api/PaddleAPI.h" diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index f5af8b0035b44d..41c429490ea333 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -874,6 +874,28 @@ class ParameterUpdater { ParameterUpdaterPrivate* m; }; +struct ParameterServerPrivate; +class ParameterServer { +private: + ParameterServer(); + +public: + static ParameterServer* createByConfigProtoPtr(const void* confPtr); + static ParameterServer* createByConfigProtoStr(const std::string& protoStr); + + ~ParameterServer(); + + /** + * @brief Start the parameter server. + * Use wait() to block until the server terminates. + */ + void start(); + void wait(); + +private: + ParameterServerPrivate* m; +}; + struct EvaluatorPrivate; class Evaluator { private: diff --git a/paddle/api/PaddleAPIPrivate.h b/paddle/api/PaddleAPIPrivate.h index f41352bfec7c33..bb1a669fa19676 100644 --- a/paddle/api/PaddleAPIPrivate.h +++ b/paddle/api/PaddleAPIPrivate.h @@ -17,6 +17,7 @@ limitations under the License. 
*/ #include "paddle/gserver/evaluators/Evaluator.h" #include "paddle/gserver/gradientmachines/GradientMachine.h" #include "paddle/parameter/ParameterUpdaterBase.h" +#include "paddle/pserver/ParameterServerController.h" #include "paddle/trainer/TrainerConfigHelper.h" struct GradientMachinePrivate { @@ -72,6 +73,10 @@ struct ParameterUpdaterPrivate { std::unique_ptr<paddle::ParameterUpdaterBase> updater; }; +struct ParameterServerPrivate { + std::unique_ptr<paddle::ParameterServerController> parameterServerController; +}; + struct ParameterPrivate { std::shared_ptr<paddle::Parameter> sharedPtr; paddle::Parameter* rawPtr; // rawPtr only used in ParameterUpdater, diff --git a/paddle/api/ParameterServer.cpp b/paddle/api/ParameterServer.cpp new file mode 100644 index 00000000000000..8ef01ce538417a --- /dev/null +++ b/paddle/api/ParameterServer.cpp @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "PaddleAPI.h" + +#include "PaddleAPIPrivate.h" + +ParameterServer::ParameterServer() : m(new ParameterServerPrivate()) {} + +ParameterServer* ParameterServer::createByConfigProtoPtr(const void* confPtr) { + auto& conf = *(const paddle::ParameterServerConfig*)(confPtr); + auto pServer = new ParameterServer(); + pServer->m->parameterServerController.reset( + paddle::ParameterServerController::create(conf)); + return pServer; +} + +ParameterServer* ParameterServer::createByConfigProtoStr( + const std::string& protoStr) { + paddle::ParameterServerConfig conf; + conf.ParseFromString(protoStr); + if (conf.IsInitialized()) { + return ParameterServer::createByConfigProtoPtr(&conf); + } else { + return nullptr; + } +} + +ParameterServer::~ParameterServer() { delete m; } + +void ParameterServer::start() { m->parameterServerController->start(); } + +void ParameterServer::wait() { m->parameterServerController->wait(); } diff --git a/paddle/py_paddle/util.py b/paddle/py_paddle/util.py index ce105d249aaf3e..1e483fb7c41140 100644 --- a/paddle/py_paddle/util.py +++ b/paddle/py_paddle/util.py @@ -15,18 +15,20 @@ Some Useful method for py_paddle. """ -import swig_paddle import os -import paddle.trainer.PyDataProviderWrapper -import paddle.proto.ParameterConfig_pb2 -import paddle.proto.ModelConfig_pb2 -import paddle.proto.TrainerConfig_pb2 import weakref import numpy import struct import sys import copy +import swig_paddle +import paddle.trainer.PyDataProviderWrapper +import paddle.proto.ParameterConfig_pb2 +import paddle.proto.ModelConfig_pb2 +import paddle.proto.TrainerConfig_pb2 +import paddle.proto.ParameterServerConfig_pb2 + def initializePaddle(*args): """ @@ -558,11 +560,29 @@ def getForwardOutput(self): swig_paddle.Trainer.getForwardOutput = getForwardOutput +def __monkeypatch_parameter_server__(): + def createFromConfigProto(protoObj): + """ + Create Parameter Server From Proto object. 
+ :param protoObj: ParameterServer Config + :type protoObj: proto.ParameterServerConfig_pb2.ParameterServerConfig + :return: paddle.ParameterServer + """ + assert isinstance( + protoObj, + paddle.proto.ParameterServerConfig_pb2.ParameterServerConfig) + return swig_paddle.ParameterServer.createByConfigProtoStr( + protoObj.SerializeToString()) + + swig_paddle.ParameterServer.createFromConfigProto = \ + staticmethod(createFromConfigProto) + + def monkeypatches(): patches = [ __monkeypatch_init_paddle__, __monkeypatch_gradient_machine__, __monkey_patch_protobuf_objects__, __monkey_patch_parameter__, - __monkey_patch_trainer__ + __monkey_patch_trainer__, __monkeypatch_parameter_server__ ] for patch in patches: patch() diff --git a/proto/ParameterServerConfig.proto b/proto/ParameterServerConfig.proto index 3068bba8b10d89..d5772d33077d52 100644 --- a/proto/ParameterServerConfig.proto +++ b/proto/ParameterServerConfig.proto @@ -29,22 +29,22 @@ message ParameterClientConfig { message ParameterServerConfig { // The ports number for parameter send, // increment based on default port number - required int32 ports_num = 1 [default = 1]; + optional int32 ports_num = 1 [default = 1]; // The ports number for parameter send, // increment based on default (port + ports_num - required int32 ports_num_for_sparse = 2 [default = 0]; + optional int32 ports_num_for_sparse = 2 [default = 0]; // network device name for pservers - required string nics = 3 [default = "xgbe0,xgbe1"]; - required string rdma_tcp = 4 [default = "tcp"]; + optional string nics = 3 [default = "xgbe0,xgbe1"]; + optional string rdma_tcp = 4 [default = "tcp"]; // Listening port for pserver - required int32 port = 5 [default = 20134]; + optional int32 port = 5 [default = 20134]; // number of gradient servers - required int32 num_gradient_servers = 6 [default = 1]; + optional int32 num_gradient_servers = 6 [default = 1]; // number of threads for sync op exec - required int32 pserver_num_threads = 7 [default = 1]; + 
optional int32 pserver_num_threads = 7 [default = 1]; // control config_.async_lagged_grad_discard_ratio() min value - required double async_lagged_ratio_min = 8 [default = 1.0]; + optional double async_lagged_ratio_min = 8 [default = 1.0]; // if async_lagged_grad_discard_ratio is not set in trainer_config.conf // use it as defalut value - required double async_lagged_ratio_default = 9 [default = 1.5]; + optional double async_lagged_ratio_default = 9 [default = 1.5]; } \ No newline at end of file