Skip to content

Commit e4156f6

Browse files
authored
Merge pull request #1018 from jacquesqiao/remote-updater
[in progress]add RemoteUpdater in api for cluster training
2 parents be3e276 + cee9944 commit e4156f6

File tree

4 files changed

+18
-5
lines changed

4 files changed

+18
-5
lines changed

paddle/api/Paddle.swig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ namespace std {
178178
%newobject ParameterOptimizer::create;
179179
%newobject ParameterOptimizer::needSpecialTraversal;
180180
%newobject ParameterUpdater::createLocalUpdater;
181+
%newobject ParameterUpdater::createRemoteUpdater;
181182

182183
%feature("director") UpdateCallback;
183184
%feature("autodoc", 1); // To generate method stub, for code hint in ide

paddle/api/PaddleAPI.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,8 @@ class ParameterUpdater {
803803

804804
public:
805805
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
806+
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
807+
int passCount);
806808
~ParameterUpdater();
807809

808810
/**

paddle/api/ParameterUpdater.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,25 @@ limitations under the License. */
1515
#include "PaddleAPI.h"
1616

1717
#include "PaddleAPIPrivate.h"
18+
#include "paddle/trainer/RemoteParameterUpdater.h"
1819
#include "paddle/trainer/ThreadParameterUpdater.h"
1920

2021
ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}
2122

2223
ParameterUpdater *ParameterUpdater::createLocalUpdater(
2324
OptimizationConfig *config) {
24-
auto param = new ParameterUpdater();
25-
param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig()));
26-
return param;
25+
auto updater = new ParameterUpdater();
26+
updater->m->updater.reset(
27+
new paddle::SgdThreadUpdater(config->m->getConfig()));
28+
return updater;
29+
}
30+
31+
ParameterUpdater *ParameterUpdater::createRemoteUpdater(
32+
OptimizationConfig *config, int passCount) {
33+
auto updater = new ParameterUpdater();
34+
updater->m->updater.reset(new paddle::RemoteParameterUpdater(
35+
config->m->getConfig(), passCount, nullptr));
36+
return updater;
2737
}
2838

2939
ParameterUpdater::~ParameterUpdater() { delete m; }

paddle/trainer/RemoteParameterUpdater.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
5656
public:
5757
RemoteParameterUpdater(
5858
const OptimizationConfig& config,
59-
int expectedPpassCount,
59+
int expectedPassCount,
6060
std::unique_ptr<ParameterUpdater>&& localUpdater = nullptr);
6161
~RemoteParameterUpdater() {
6262
if (controllerThread_) {
@@ -146,7 +146,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
146146
BatchStatus batchStatus_;
147147
/// controller thread for sync-sgd
148148
std::unique_ptr<std::thread> controllerThread_;
149-
/// passed alread finished
149+
/// passed already finished
150150
int64_t passCount_;
151151
/// expected passes to finished
152152
int64_t expectedPassCount_;

0 commit comments

Comments
 (0)