Skip to content

Commit f8a529c

Browse files
authored
Merge pull request #1051 from jacquesqiao/add-pserver-util
Add ParameterServerController for parameter server python api
2 parents 77eb729 + aa9f516 commit f8a529c

File tree

8 files changed

+244
-108
lines changed

8 files changed

+244
-108
lines changed

demo/quick_start/cluster/cluster_train.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ log_file="$bin_dir/train.log"
2525
pushd "$home_dir"
2626
cfg=trainer_config.lr.py
2727
paddle train \
28+
--start_pserver=false \
2829
--config=$cfg \
2930
--save_dir=${model_dir} \
3031
--trainer_count=4 \

paddle/pserver/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@ set(PSERVER_SOURCES
2424
BaseClient.cpp
2525
ParameterClient2.cpp
2626
ParameterServer2.cpp
27-
SparseParameterDistribution.cpp)
27+
SparseParameterDistribution.cpp
28+
ParameterServerController.cpp)
2829

2930
set(PSERVER_HEADERS
3031
BaseClient.h
3132
ParameterClient2.h
3233
ParameterServer2.h
33-
SparseParameterDistribution.h)
34+
SparseParameterDistribution.h
35+
ParameterServerController.h)
3436

3537
add_library(paddle_pserver STATIC
3638
${PSERVER_SOURCES})

paddle/pserver/ParameterServer2Main.cpp

Lines changed: 5 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -13,66 +13,17 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include <fstream>
16-
#include "paddle/utils/StringUtil.h"
17-
#include "paddle/utils/Util.h"
18-
19-
#include "ParameterServer2.h"
20-
#include "RDMANetwork.h"
21-
#include "paddle/utils/Flags.h"
16+
#include "ParameterServerController.h"
2217

2318
using namespace paddle; // NOLINT
2419

2520
int main(int argc, char** argv) {
2621
initMain(argc, argv);
2722

28-
std::vector<std::string> devices;
29-
std::vector<std::shared_ptr<ParameterServer2>> pservers;
30-
31-
// round robin to loadbalance RDMA server ENGINE
32-
int rdmaCpu = 0;
33-
int onlineCpus = rdma::numCpus();
34-
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
35-
if (FLAGS_nics.empty()) {
36-
pservers.resize(numPorts);
37-
for (int i = 0; i < numPorts; ++i) {
38-
if (FLAGS_rdma_tcp == "rdma") {
39-
pservers[i].reset(
40-
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
41-
rdmaCpu = rdmaCpu % onlineCpus;
42-
} else {
43-
pservers[i].reset(new ParameterServer2(std::string(), FLAGS_port + i));
44-
}
45-
CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
46-
<< FLAGS_port + i;
47-
LOG(INFO) << "pserver started : " << FLAGS_port + i;
48-
pservers[i]->start();
49-
}
50-
} else {
51-
str::split(FLAGS_nics, ',', &devices);
52-
pservers.resize(devices.size() * numPorts);
53-
for (int i = 0; i < numPorts; ++i) {
54-
for (size_t j = 0; j < devices.size(); ++j) {
55-
if (FLAGS_rdma_tcp == "rdma") {
56-
pservers[i * devices.size() + j].reset(new ParameterServer2(
57-
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
58-
rdmaCpu = rdmaCpu % onlineCpus;
59-
} else {
60-
pservers[i * devices.size() + j].reset(
61-
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
62-
}
63-
CHECK(pservers[i * devices.size() + j]->init())
64-
<< "Fail to initialize parameter server" << devices[j]
65-
<< FLAGS_port + i;
66-
LOG(INFO) << "pserver started : " << devices[j] << ":"
67-
<< FLAGS_port + i;
68-
pservers[i * devices.size() + j]->start();
69-
}
70-
}
71-
}
72-
73-
for (auto& pserver : pservers) {
74-
pserver->join();
75-
}
23+
std::unique_ptr<ParameterServerController> parameterServerPtr(
24+
paddle::ParameterServerController::createFromGflags());
25+
parameterServerPtr->start();
26+
parameterServerPtr->wait();
7627

7728
return 0;
7829
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "ParameterServerController.h"
16+
17+
namespace paddle {
18+
19+
ParameterServerController::ParameterServerController(
20+
const ParameterServerConfig& config) {
21+
// round robin to load balance RDMA server ENGINE
22+
std::vector<std::string> devices;
23+
int rdmaCpu = 0;
24+
int onlineCpus = rdma::numCpus();
25+
int numPorts = config.ports_num() + config.ports_num_for_sparse();
26+
27+
if (config.nics().empty()) {
28+
parameterServers_.resize(numPorts);
29+
for (int i = 0; i < numPorts; ++i) {
30+
if (config.rdma_tcp() == "rdma") {
31+
parameterServers_[i].reset(
32+
new ParameterServer2(std::string(), config.port() + i, rdmaCpu++));
33+
rdmaCpu = rdmaCpu % onlineCpus;
34+
} else {
35+
parameterServers_[i].reset(
36+
new ParameterServer2(std::string(), config.port() + i));
37+
}
38+
CHECK(parameterServers_[i]->init()) << "Fail to initialize parameter "
39+
"server on port "
40+
<< config.port() + i;
41+
}
42+
} else {
43+
str::split(config.nics(), ',', &devices);
44+
parameterServers_.resize(devices.size() * numPorts);
45+
for (int i = 0; i < numPorts; ++i) {
46+
for (size_t j = 0; j < devices.size(); ++j) {
47+
if (config.rdma_tcp() == "rdma") {
48+
parameterServers_[i * devices.size() + j].reset(new ParameterServer2(
49+
getIpAddr(devices[j]), config.port() + i, rdmaCpu++));
50+
rdmaCpu = rdmaCpu % onlineCpus;
51+
} else {
52+
parameterServers_[i * devices.size() + j].reset(
53+
new ParameterServer2(getIpAddr(devices[j]), config.port() + i));
54+
}
55+
CHECK(parameterServers_[i * devices.size() + j]->init())
56+
<< "Fail to initialize parameter server with device " << devices[j]
57+
<< config.port() + i;
58+
}
59+
}
60+
}
61+
}
62+
63+
ParameterServerController::~ParameterServerController() { this->wait(); }
64+
65+
ParameterServerController* ParameterServerController::createFromGflags() {
66+
ParameterServerConfig config;
67+
68+
config.set_nics(FLAGS_nics);
69+
config.set_rdma_tcp(FLAGS_rdma_tcp);
70+
config.set_port(FLAGS_port);
71+
config.set_ports_num(FLAGS_ports_num);
72+
config.set_ports_num_for_sparse(FLAGS_ports_num_for_sparse);
73+
74+
return create(config);
75+
}
76+
77+
ParameterServerController* ParameterServerController::create(
78+
const ParameterServerConfig& config) {
79+
return new ParameterServerController(config);
80+
}
81+
82+
void ParameterServerController::start() {
83+
LOG(INFO) << "number of parameterServer instances: "
84+
<< parameterServers_.size();
85+
int i = 0;
86+
for (const auto& parameterServer : parameterServers_) {
87+
LOG(INFO) << "Starting parameterServer[" << i << "]";
88+
parameterServer->start();
89+
i++;
90+
}
91+
}
92+
93+
void ParameterServerController::wait() {
94+
int i = 0;
95+
for (const auto& parameterServer : parameterServers_) {
96+
LOG(INFO) << "Waiting parameterServer[" << i << "]";
97+
parameterServer->join();
98+
i++;
99+
}
100+
}
101+
102+
} // namespace paddle
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "ParameterServer2.h"
18+
#include "ParameterServerConfig.pb.h"
19+
#include "RDMANetwork.h"
20+
#include "paddle/utils/StringUtil.h"
21+
22+
namespace paddle {
23+
24+
/**
25+
* @brief ParameterServerController is used for create, init and manage multi
26+
* parameter server instances. The num of the instances is decided by port
27+
* num(the ports number for parameter send) and network devices configured
28+
* by gflags or proto.
29+
*/
30+
class ParameterServerController final {
31+
public:
32+
DISABLE_COPY(ParameterServerController);
33+
34+
/**
35+
* @brief Ctor, Create a ParameterServerController from ParameterServerConfig.
36+
*/
37+
explicit ParameterServerController(const ParameterServerConfig& config);
38+
39+
/**
40+
* @brief Dtor.
41+
*/
42+
~ParameterServerController();
43+
44+
/**
45+
* @brief create ParameterServerController from gflags, this is used for
46+
* compatibility with the old usage of configuration by gflags.
47+
*/
48+
static ParameterServerController* createFromGflags();
49+
50+
/**
51+
* @brief create ParameterServerController with ParameterServerConfig, remove
52+
* gflags from ParameterServer. Init all ParameterServer2 instances according
53+
* to
54+
* the config.
55+
*/
56+
static ParameterServerController* create(const ParameterServerConfig& config);
57+
58+
/**
59+
* @brief start all ParameterServer2 instances in this
60+
* ParameterServerController.
61+
*/
62+
void start();
63+
64+
/**
65+
* @brief join and wait for all ParameterServer2 instances thread in this
66+
* ParameterServerController.
67+
*/
68+
void wait();
69+
70+
private:
71+
std::vector<std::unique_ptr<ParameterServer2>> parameterServers_;
72+
};
73+
74+
} // namespace paddle

paddle/trainer/TrainerMain.cpp

Lines changed: 6 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/pserver/ParameterServer2.h"
16-
#include "paddle/utils/Common.h"
15+
#include <fenv.h>
16+
#include "paddle/pserver/ParameterServerController.h"
1717
#include "paddle/utils/PythonUtil.h"
18-
#include "paddle/utils/StringUtil.h"
1918

2019
#include "ParamUtil.h"
2120
#include "Trainer.h"
22-
#include "paddle/pserver/RDMANetwork.h"
2321

2422
DEFINE_bool(start_pserver, false, "Whether to start pserver");
2523
DECLARE_int32(gpu_id);
@@ -38,54 +36,11 @@ int main(int argc, char** argv) {
3836
initMain(argc, argv);
3937
initPython(argc, argv);
4038

41-
std::vector<std::unique_ptr<ParameterServer2>> pservers;
42-
std::vector<std::string> devices;
43-
39+
std::unique_ptr<ParameterServerController> parameterServerPtr(nullptr);
4440
if (FLAGS_start_pserver) {
45-
// round robin to loadbalance RDMA server ENGINE
46-
int rdmaCpu = 0;
47-
int onlineCpus = rdma::numCpus();
48-
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
49-
if (FLAGS_nics.empty()) {
50-
pservers.resize(numPorts);
51-
for (int i = 0; i < numPorts; ++i) {
52-
if (FLAGS_rdma_tcp == "rdma") {
53-
pservers[i].reset(
54-
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
55-
rdmaCpu = rdmaCpu % onlineCpus;
56-
} else {
57-
pservers[i].reset(
58-
new ParameterServer2(std::string(), FLAGS_port + i));
59-
}
60-
61-
CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
62-
<< FLAGS_port + i;
63-
LOG(INFO) << "pserver started : " << FLAGS_port + i;
64-
pservers[i]->start();
65-
}
66-
} else {
67-
str::split(FLAGS_nics, ',', &devices);
68-
pservers.resize(devices.size() * numPorts);
69-
for (int i = 0; i < numPorts; ++i) {
70-
for (size_t j = 0; j < devices.size(); ++j) {
71-
if (FLAGS_rdma_tcp == "rdma") {
72-
pservers[i * devices.size() + j].reset(new ParameterServer2(
73-
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
74-
rdmaCpu = rdmaCpu % onlineCpus;
75-
} else {
76-
pservers[i * devices.size() + j].reset(
77-
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
78-
}
79-
80-
CHECK(pservers[i * devices.size() + j]->init())
81-
<< "Fail to initialize parameter server" << devices[j]
82-
<< FLAGS_port + i;
83-
LOG(INFO) << "pserver started : " << devices[j] << ":"
84-
<< FLAGS_port + i;
85-
pservers[i * devices.size() + j]->start();
86-
}
87-
}
88-
}
41+
parameterServerPtr.reset(
42+
paddle::ParameterServerController::createFromGflags());
43+
parameterServerPtr->start();
8944
}
9045
Trainer trainer;
9146
auto config = TrainerConfigHelper::createFromFlags();

proto/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ set(proto_filenames
44
ModelConfig.proto
55
ParameterConfig.proto
66
ParameterService.proto
7-
TrainerConfig.proto)
7+
TrainerConfig.proto
8+
ParameterServerConfig.proto)
89

910
set(PROTO_GEN)
1011
set(PROTO_GEN_PY)

0 commit comments

Comments
 (0)