Skip to content

Commit f3c61cb

Browse files
committed
add pserver util and parameter server config
1 parent d52ebb0 commit f3c61cb

File tree

9 files changed

+199
-108
lines changed

9 files changed

+199
-108
lines changed

demo/quick_start/cluster/cluster_train.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ log_file="$bin_dir/train.log"
2525
pushd "$home_dir"
2626
cfg=trainer_config.lr.py
2727
paddle train \
28+
--start_pserver=false \
2829
--config=$cfg \
2930
--save_dir=${model_dir} \
3031
--trainer_count=4 \

demo/quick_start/cluster/pserver.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ source "$bin_dir/env.sh"
1919
paddle pserver \
2020
--nics=`get_nics` \
2121
--port=7164 \
22-
--ports_num=1 \
22+
--ports_num=2 \
2323
--ports_num_for_sparse=1 \
2424
--num_gradient_servers=1 \
2525
--comment="paddle_pserver" \

paddle/pserver/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@ set(PSERVER_SOURCES
2424
BaseClient.cpp
2525
ParameterClient2.cpp
2626
ParameterServer2.cpp
27-
SparseParameterDistribution.cpp)
27+
SparseParameterDistribution.cpp
28+
PServerUtil.cpp)
2829

2930
set(PSERVER_HEADERS
3031
BaseClient.h
3132
ParameterClient2.h
3233
ParameterServer2.h
33-
SparseParameterDistribution.h)
34+
SparseParameterDistribution.h
35+
PServerUtil.h)
3436

3537
add_library(paddle_pserver STATIC
3638
${PSERVER_SOURCES})

paddle/pserver/PServerUtil.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "PServerUtil.h"

#include <memory>
16+
17+
namespace paddle {
18+
19+
/**
 * Builds a ParameterServerConfig from the pserver command-line flags.
 *
 * Copies FLAGS_nics, FLAGS_port, FLAGS_ports_num, FLAGS_ports_num_for_sparse
 * and FLAGS_rdma_tcp into a freshly allocated message. The remaining fields
 * of the message are left unset.
 *
 * @return a heap-allocated config; the caller owns it and must delete it.
 */
ParameterServerConfig* PServerUtil::initConfig() {
  ParameterServerConfig* config = new ParameterServerConfig();
  config->set_nics(FLAGS_nics);
  config->set_port(FLAGS_port);
  config->set_ports_num(FLAGS_ports_num);
  // The PServerUtil constructor sizes its server list from
  // ports_num + ports_num_for_sparse; without this line the sparse ports
  // requested on the command line would silently fall back to the proto
  // default and never be opened.
  config->set_ports_num_for_sparse(FLAGS_ports_num_for_sparse);
  config->set_rdma_tcp(FLAGS_rdma_tcp);
  return config;
}
27+
28+
/**
 * Creates a PServerUtil configured from the command-line flags.
 *
 * @return a heap-allocated PServerUtil; the caller owns it.
 */
PServerUtil* PServerUtil::create() {
  // initConfig() returns an owning raw pointer. Hold it in a unique_ptr so the
  // temporary config is released once the servers have been built; the
  // constructor only reads the config and does not keep a reference to it.
  std::unique_ptr<ParameterServerConfig> config(PServerUtil::initConfig());
  return PServerUtil::create(*config);
}
32+
33+
/**
 * Creates a PServerUtil from an explicit configuration.
 *
 * @param config settings used to build the parameter servers.
 * @return a heap-allocated PServerUtil; the caller owns it.
 */
PServerUtil* PServerUtil::create(const ParameterServerConfig& config) {
  PServerUtil* instance = new PServerUtil(config);
  return instance;
}
36+
37+
/**
 * Builds and initializes the parameter servers described by `config`.
 *
 * One ParameterServer2 is created for every port in
 * [port, port + ports_num + ports_num_for_sparse). When `nics` is non-empty
 * it is split on ',' and one server per (port, device) pair is created, bound
 * to the address returned by getIpAddr() for that device; otherwise a single
 * server with an empty address is created per port. Servers are stored
 * port-major: index = portIndex * deviceCount + deviceIndex.
 *
 * Every server is init()'ed here and the process aborts through CHECK if any
 * initialization fails. The servers are not started; call start() for that.
 *
 * @param config settings to read nics, port, ports_num, ports_num_for_sparse
 *               and rdma_tcp from. It is only read during construction.
 */
PServerUtil::PServerUtil(const ParameterServerConfig& config) {
  const int numPorts = config.ports_num() + config.ports_num_for_sparse();
  const bool useRdma = config.rdma_tcp() == "rdma";

  // round robin to load balance RDMA server ENGINE
  int rdmaCpu = 0;
  const int onlineCpus = rdma::numCpus();

  // Creates the server for `slot`, bound to addr:port, and reports whether
  // its init() succeeded.
  auto buildServer =
      [&](size_t slot, const std::string& addr, int port) -> bool {
    if (useRdma) {
      pservers_[slot].reset(new ParameterServer2(addr, port, rdmaCpu++));
      rdmaCpu = rdmaCpu % onlineCpus;
    } else {
      pservers_[slot].reset(new ParameterServer2(addr, port));
    }
    return pservers_[slot]->init();
  };

  if (config.nics().empty()) {
    // No device list: one server per port with an empty address.
    pservers_.resize(numPorts);
    for (int i = 0; i < numPorts; ++i) {
      const int port = config.port() + i;
      CHECK(buildServer(i, std::string(), port))
          << "Fail to initialize parameter server " << port;
    }
    return;
  }

  std::vector<std::string> devices;
  str::split(config.nics(), ',', &devices);
  pservers_.resize(devices.size() * numPorts);
  for (int i = 0; i < numPorts; ++i) {
    const int port = config.port() + i;
    for (size_t j = 0; j < devices.size(); ++j) {
      CHECK(buildServer(i * devices.size() + j, getIpAddr(devices[j]), port))
          << "Fail to initialize parameter server " << devices[j] << ":"
          << port;
    }
  }
}
78+
79+
/// Joins every owned parameter server before the object is destroyed.
PServerUtil::~PServerUtil() {
  join();
}
80+
81+
/**
 * Starts every owned parameter server, in the order they were created.
 * Logs the number of servers first, then each server's index just before
 * its start() call.
 */
void PServerUtil::start() {
  const size_t total = pservers_.size();
  LOG(INFO) << "pserver sizes : " << total;
  for (size_t idx = 0; idx < total; ++idx) {
    LOG(INFO) << "pserver started : " << idx;
    pservers_[idx]->start();
  }
}
90+
91+
/**
 * Calls join() on every owned parameter server, in creation order.
 * Logs the number of servers first, then each server's index just before
 * its join() call.
 */
void PServerUtil::join() {
  const size_t total = pservers_.size();
  LOG(INFO) << "pserver sizes : " << total;
  for (size_t idx = 0; idx < total; ++idx) {
    LOG(INFO) << "pserver join : " << idx;
    pservers_[idx]->join();
  }
}
100+
101+
} // namespace paddle

paddle/pserver/PServerUtil.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "ParameterServer2.h"
18+
#include "ParameterServerConfig.pb.h"
19+
#include "RDMANetwork.h"
20+
#include "paddle/utils/StringUtil.h"
21+
22+
namespace paddle {
23+
24+
class PServerUtil {
25+
public:
26+
DISABLE_COPY(PServerUtil);
27+
static PServerUtil* create();
28+
static PServerUtil* create(const ParameterServerConfig& config);
29+
explicit PServerUtil(const ParameterServerConfig& config);
30+
~PServerUtil();
31+
static ParameterServerConfig* initConfig();
32+
void start();
33+
void join();
34+
35+
private:
36+
std::vector<std::shared_ptr<ParameterServer2>> pservers_;
37+
};
38+
39+
} // namespace paddle

paddle/pserver/ParameterServer2Main.cpp

Lines changed: 5 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -13,66 +13,17 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include <fstream>
16-
#include "paddle/utils/StringUtil.h"
17-
#include "paddle/utils/Util.h"
18-
19-
#include "ParameterServer2.h"
20-
#include "RDMANetwork.h"
21-
#include "paddle/utils/Flags.h"
16+
#include "PServerUtil.h"
17+
#include "paddle/trainer/ParamUtil.h"
2218

2319
using namespace paddle; // NOLINT
2420

2521
int main(int argc, char** argv) {
2622
initMain(argc, argv);
2723

28-
std::vector<std::string> devices;
29-
std::vector<std::shared_ptr<ParameterServer2>> pservers;
30-
31-
// round robin to loadbalance RDMA server ENGINE
32-
int rdmaCpu = 0;
33-
int onlineCpus = rdma::numCpus();
34-
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
35-
if (FLAGS_nics.empty()) {
36-
pservers.resize(numPorts);
37-
for (int i = 0; i < numPorts; ++i) {
38-
if (FLAGS_rdma_tcp == "rdma") {
39-
pservers[i].reset(
40-
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
41-
rdmaCpu = rdmaCpu % onlineCpus;
42-
} else {
43-
pservers[i].reset(new ParameterServer2(std::string(), FLAGS_port + i));
44-
}
45-
CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
46-
<< FLAGS_port + i;
47-
LOG(INFO) << "pserver started : " << FLAGS_port + i;
48-
pservers[i]->start();
49-
}
50-
} else {
51-
str::split(FLAGS_nics, ',', &devices);
52-
pservers.resize(devices.size() * numPorts);
53-
for (int i = 0; i < numPorts; ++i) {
54-
for (size_t j = 0; j < devices.size(); ++j) {
55-
if (FLAGS_rdma_tcp == "rdma") {
56-
pservers[i * devices.size() + j].reset(new ParameterServer2(
57-
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
58-
rdmaCpu = rdmaCpu % onlineCpus;
59-
} else {
60-
pservers[i * devices.size() + j].reset(
61-
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
62-
}
63-
CHECK(pservers[i * devices.size() + j]->init())
64-
<< "Fail to initialize parameter server" << devices[j]
65-
<< FLAGS_port + i;
66-
LOG(INFO) << "pserver started : " << devices[j] << ":"
67-
<< FLAGS_port + i;
68-
pservers[i * devices.size() + j]->start();
69-
}
70-
}
71-
}
72-
73-
for (auto& pserver : pservers) {
74-
pserver->join();
75-
}
24+
std::unique_ptr<PServerUtil> pServerPtr(paddle::PServerUtil::create());
25+
pServerPtr->start();
26+
pServerPtr->join();
7627

7728
return 0;
7829
}

paddle/trainer/TrainerMain.cpp

Lines changed: 3 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include <fenv.h>
16-
#include "paddle/pserver/ParameterServer2.h"
16+
#include "paddle/pserver/PServerUtil.h"
1717
#include "paddle/utils/Excepts.h"
1818
#include "paddle/utils/PythonUtil.h"
19-
#include "paddle/utils/StringUtil.h"
2019

2120
#include "ParamUtil.h"
2221
#include "Trainer.h"
23-
#include "paddle/pserver/RDMANetwork.h"
2422

2523
DEFINE_bool(start_pserver, false, "Whether to start pserver");
2624
DECLARE_int32(gpu_id);
@@ -39,54 +37,9 @@ int main(int argc, char** argv) {
3937
initMain(argc, argv);
4038
initPython(argc, argv);
4139

42-
std::vector<std::unique_ptr<ParameterServer2>> pservers;
43-
std::vector<std::string> devices;
44-
4540
if (FLAGS_start_pserver) {
46-
// round robin to loadbalance RDMA server ENGINE
47-
int rdmaCpu = 0;
48-
int onlineCpus = rdma::numCpus();
49-
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
50-
if (FLAGS_nics.empty()) {
51-
pservers.resize(numPorts);
52-
for (int i = 0; i < numPorts; ++i) {
53-
if (FLAGS_rdma_tcp == "rdma") {
54-
pservers[i].reset(
55-
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
56-
rdmaCpu = rdmaCpu % onlineCpus;
57-
} else {
58-
pservers[i].reset(
59-
new ParameterServer2(std::string(), FLAGS_port + i));
60-
}
61-
62-
CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
63-
<< FLAGS_port + i;
64-
LOG(INFO) << "pserver started : " << FLAGS_port + i;
65-
pservers[i]->start();
66-
}
67-
} else {
68-
str::split(FLAGS_nics, ',', &devices);
69-
pservers.resize(devices.size() * numPorts);
70-
for (int i = 0; i < numPorts; ++i) {
71-
for (size_t j = 0; j < devices.size(); ++j) {
72-
if (FLAGS_rdma_tcp == "rdma") {
73-
pservers[i * devices.size() + j].reset(new ParameterServer2(
74-
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
75-
rdmaCpu = rdmaCpu % onlineCpus;
76-
} else {
77-
pservers[i * devices.size() + j].reset(
78-
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
79-
}
80-
81-
CHECK(pservers[i * devices.size() + j]->init())
82-
<< "Fail to initialize parameter server" << devices[j]
83-
<< FLAGS_port + i;
84-
LOG(INFO) << "pserver started : " << devices[j] << ":"
85-
<< FLAGS_port + i;
86-
pservers[i * devices.size() + j]->start();
87-
}
88-
}
89-
}
41+
PServerUtil* pServerUtil = paddle::PServerUtil::create();
42+
pServerUtil->start();
9043
}
9144
Trainer trainer;
9245
auto config = TrainerConfigHelper::createFromFlags();

proto/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ set(proto_filenames
44
ModelConfig.proto
55
ParameterConfig.proto
66
ParameterService.proto
7-
TrainerConfig.proto)
7+
TrainerConfig.proto
8+
ParameterServerConfig.proto)
89

910
set(PROTO_GEN)
1011
set(PROTO_GEN_PY)

proto/ParameterServerConfig.proto

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
syntax = "proto2";
15+
16+
package paddle;
17+
18+
message ParameterClientConfig {
19+
required int32 trainer_id = 1;
20+
}
21+
22+
message ParameterServerConfig {
23+
// The ports number for parameter send,
24+
// increment based on default port number
25+
required int32 ports_num = 1 [default = 1];
26+
// The ports number for sparse parameter send,
27+
// increment based on default (port + ports_num)
28+
required int32 ports_num_for_sparse = 2 [default = 0];
29+
// network device name for pservers
30+
required string nics = 3 [default = "xgbe0,xgbe1"];
31+
required string rdma_tcp = 4 [default = "tcp"];
32+
// Listening port for pserver
33+
required int32 port = 5 [default = 20134];
34+
// number of gradient servers
35+
required int32 num_gradient_servers = 6 [default = 1];
36+
// number of threads for sync op exec
37+
required int32 pserver_num_threads = 7 [default = 1];
38+
// control config_.async_lagged_grad_discard_ratio() min value
39+
required double async_lagged_ratio_min = 8 [default = 1.0];
40+
// if async_lagged_grad_discard_ratio is not set in trainer_config.conf
41+
// use it as default value
42+
required double async_lagged_ratio_default = 9 [default = 1.5];
43+
}

0 commit comments

Comments
 (0)