File tree Expand file tree Collapse file tree 4 files changed +18
-5
lines changed Expand file tree Collapse file tree 4 files changed +18
-5
lines changed Original file line number Diff line number Diff line change @@ -178,6 +178,7 @@ namespace std {
178
178
%newobject ParameterOptimizer::create;
179
179
%newobject ParameterOptimizer::needSpecialTraversal;
180
180
%newobject ParameterUpdater::createLocalUpdater;
181
+ %newobject ParameterUpdater::createRemoteUpdater;
181
182
182
183
%feature("director") UpdateCallback;
183
184
%feature("autodoc", 1); // To generate method stub, for code hint in ide
Original file line number Diff line number Diff line change @@ -803,6 +803,8 @@ class ParameterUpdater {
803
803
804
804
public:
805
805
static ParameterUpdater* createLocalUpdater (OptimizationConfig* config);
806
+ static ParameterUpdater* createRemoteUpdater (OptimizationConfig* config,
807
+ int passCount);
806
808
~ParameterUpdater ();
807
809
808
810
/* *
Original file line number Diff line number Diff line change @@ -15,15 +15,25 @@ limitations under the License. */
15
15
#include " PaddleAPI.h"
16
16
17
17
#include " PaddleAPIPrivate.h"
18
+ #include " paddle/trainer/RemoteParameterUpdater.h"
18
19
#include " paddle/trainer/ThreadParameterUpdater.h"
19
20
20
21
ParameterUpdater::ParameterUpdater () : m(new ParameterUpdaterPrivate()) {}
21
22
22
23
ParameterUpdater *ParameterUpdater::createLocalUpdater (
23
24
OptimizationConfig *config) {
24
- auto param = new ParameterUpdater ();
25
- param->m ->updater .reset (new paddle::SgdThreadUpdater (config->m ->getConfig ()));
26
- return param;
25
+ auto updater = new ParameterUpdater ();
26
+ updater->m ->updater .reset (
27
+ new paddle::SgdThreadUpdater (config->m ->getConfig ()));
28
+ return updater;
29
+ }
30
+
31
+ ParameterUpdater *ParameterUpdater::createRemoteUpdater (
32
+ OptimizationConfig *config, int passCount) {
33
+ auto updater = new ParameterUpdater ();
34
+ updater->m ->updater .reset (new paddle::RemoteParameterUpdater (
35
+ config->m ->getConfig (), passCount, nullptr ));
36
+ return updater;
27
37
}
28
38
29
39
ParameterUpdater::~ParameterUpdater () { delete m; }
Original file line number Diff line number Diff line change @@ -56,7 +56,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
56
56
public:
57
57
RemoteParameterUpdater (
58
58
const OptimizationConfig& config,
59
- int expectedPpassCount ,
59
+ int expectedPassCount ,
60
60
std::unique_ptr<ParameterUpdater>&& localUpdater = nullptr );
61
61
~RemoteParameterUpdater () {
62
62
if (controllerThread_) {
@@ -146,7 +146,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
146
146
BatchStatus batchStatus_;
147
147
// / controller thread for sync-sgd
148
148
std::unique_ptr<std::thread> controllerThread_;
149
- // / passed alread finished
149
+ // / passed already finished
150
150
int64_t passCount_;
151
151
// / expected passes to finished
152
152
int64_t expectedPassCount_;
You can’t perform that action at this time.
0 commit comments