Skip to content

Commit c33312f

Browse files
authored
bug fix: invalid learning rate decay in pserver async mode (#20325) (#20635)
* bug fix: invalid learning rate decay in pserver async mode
1 parent dcd8e30 commit c33312f

File tree

15 files changed

+296
-8
lines changed

15 files changed

+296
-8
lines changed

paddle/fluid/operators/distributed/grpc/grpc_client.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,35 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
438438
return h;
439439
}
440440

441+
// Issues an asynchronous DistributeNotify RPC to the endpoint `ep`.
//
// ep:       endpoint used to look up the gRPC channel.
// type:     notification kind; sent as the request's varname.
// time_out: passed to the processor's Prepare() together with the handle.
//
// Returns the VarHandle associated with the call. When profiling is enabled
// the handle is waited on before returning; otherwise the caller is expected
// to wait (on the handle or on the client).
//
// NOTE(review): the handle is always labelled LEARNING_RATE_DECAY_MESSAGE,
// independent of `type` — confirm this is intended if other notification
// types are ever sent through this method.
VarHandlePtr GRPCClient::AsyncDistributeNotify(const std::string& ep,
                                               const std::string& type,
                                               int64_t time_out) {
  const auto channel = GetChannel(ep);

  // The processor pointer is passed to Finish() below as the call tag.
  // NOTE(review): presumably released by whoever consumes the completion
  // queue — verify against the other Async* methods.
  DistributeNotifyProcessor* processor =
      new DistributeNotifyProcessor(channel);

  const std::string method = kRequestNotify;

  VarHandlePtr handle(
      new VarHandle(ep, method, LEARNING_RATE_DECAY_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  sendrecv::VariableMessage request;
  request.set_varname(type);

  platform::RecordRPCEvent record_event(method);

  auto call = processor->stub_->AsyncDistributeNotify(
      processor->context_.get(), request, &cq_);
  call->Finish(&processor->reply_, &processor->status_,
               reinterpret_cast<void*>(processor));
  ++req_count_;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}
469+
441470
bool GRPCClient::Wait() {
442471
std::unique_lock<std::mutex> lk(sync_mutex_);
443472
sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });

paddle/fluid/operators/distributed/grpc/grpc_client.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,20 @@ class CheckpointNotifyProcessor : public BaseProcessor {
173173
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
174174
};
175175

176+
class DistributeNotifyProcessor : public BaseProcessor {
177+
public:
178+
explicit DistributeNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
179+
: BaseProcessor() {
180+
stub_ = sendrecv::SendRecvService::NewStub(ch);
181+
}
182+
183+
virtual ~DistributeNotifyProcessor() {}
184+
185+
void ProcessImpl() override {}
186+
sendrecv::VoidMessage reply_;
187+
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
188+
};
189+
176190
class GRPCClient : public RPCClient {
177191
public:
178192
GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
@@ -225,6 +239,10 @@ class GRPCClient : public RPCClient {
225239
const std::string& ep, const std::string& dir,
226240
int64_t time_out = FLAGS_rpc_deadline) override;
227241

242+
VarHandlePtr AsyncDistributeNotify(
243+
const std::string& ep, const std::string& type,
244+
int64_t time_out = FLAGS_rpc_deadline) override;
245+
228246
VarHandlePtr AsyncSendComplete(
229247
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
230248

paddle/fluid/operators/distributed/grpc/grpc_server.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,43 @@ class RequestCheckpointNotify final : public RequestBase {
393393
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
394394
};
395395

396+
class RequestNotify final : public RequestBase {
397+
public:
398+
explicit RequestNotify(GrpcService::AsyncService* service,
399+
::grpc::ServerCompletionQueue* cq,
400+
RequestHandler* request_handler, int req_id)
401+
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
402+
request_.reset(new GRPCVariableResponse(request_handler->scope(),
403+
request_handler->dev_ctx()));
404+
int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
405+
service_->RequestAsyncUnary(
406+
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
407+
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
408+
}
409+
410+
virtual ~RequestNotify() {}
411+
412+
std::string GetReqName() override { return request_->Varname(); }
413+
414+
void Process() override {
415+
auto scope = request_->GetMutableLocalScope();
416+
417+
std::string varname = request_->Varname();
418+
int trainer_id = request_->GetTrainerId();
419+
420+
VLOG(4) << "RequestNotify notify: " << varname
421+
<< ", trainer id: " << trainer_id;
422+
423+
request_handler_->Handle(varname, scope, nullptr, nullptr, trainer_id);
424+
Finish(reply_, &responder_);
425+
}
426+
427+
protected:
428+
std::shared_ptr<GRPCVariableResponse> request_;
429+
sendrecv::VoidMessage reply_;
430+
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
431+
};
432+
396433
void AsyncGRPCServer::WaitServerReady() {
397434
VLOG(4) << "AsyncGRPCServer is waiting server ready";
398435
std::unique_lock<std::mutex> lock(this->mutex_ready_);
@@ -526,6 +563,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
526563
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
527564
} else if (rpc_name == kRequestCheckpoint) {
528565
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
566+
} else if (rpc_name == kRequestNotify) {
567+
b = new RequestNotify(&service_, cq.get(), handler, req_id);
529568
} else {
530569
PADDLE_ENFORCE(false, "not supported rpc");
531570
}

paddle/fluid/operators/distributed/grpc/grpc_service.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,11 @@ enum class GrpcMethod {
8484
kGetVariableNoBarrier,
8585
kGetMonomerVariable,
8686
kGetMonomerBarrier,
87+
kRequestNotify,
8788
};
8889

8990
static const int kGrpcNumMethods =
90-
static_cast<int>(GrpcMethod::kGetMonomerBarrier) + 1;
91+
static_cast<int>(GrpcMethod::kRequestNotify) + 1;
9192

9293
inline const char* GrpcMethodName(GrpcMethod id) {
9394
switch (id) {
@@ -105,6 +106,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
105106
return "/sendrecv.SendRecvService/PrefetchVariable";
106107
case GrpcMethod::kCheckpointNotify:
107108
return "/sendrecv.SendRecvService/CheckpointNotify";
109+
case GrpcMethod::kRequestNotify:
110+
return "/sendrecv.SendRecvService/DistributeNotify";
108111
}
109112

110113
// Shouldn't be reached.

paddle/fluid/operators/distributed/request_handler.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch";
4545
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
4646
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
4747
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
48+
constexpr char kRequestNotify[] = "RequestNotify";
4849

4950
constexpr char kSendRPC[] = "SendRPC";
5051
constexpr char kGetRPC[] = "GetRPC";
@@ -62,6 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
6263
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
6364
#define COMPLETE_MESSAGE "COMPLETE@RECV"
6465
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
66+
#define LEARNING_RATE_DECAY_MESSAGE "LRDECAY@RECV"
6567

6668
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
6769
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
@@ -188,6 +190,11 @@ class RequestHandler {
188190
sparse_grad_to_param_ = g;
189191
}
190192

193+
void SetLrDecayPreparedCtx(
194+
std::shared_ptr<framework::ExecutorPrepareContext> g) {
195+
lr_decay_prepared_ctx_ = g;
196+
}
197+
191198
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
192199

193200
// Get attributes.
@@ -238,6 +245,8 @@ class RequestHandler {
238245
grad_to_prepared_ctx_;
239246
std::unordered_map<std::string, std::string>* sparse_grad_to_param_;
240247

248+
// used for lr decay
249+
std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_prepared_ctx_;
241250
RPCServer* rpc_server_;
242251
};
243252

paddle/fluid/operators/distributed/request_handler_impl.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,23 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
251251
return true;
252252
}
253253

254+
bool RequestNotifyHandler::Handle(const std::string& varname,
255+
framework::Scope* scope,
256+
framework::Variable* invar,
257+
framework::Variable** outvar,
258+
const int trainer_id,
259+
const std::string& out_var_name,
260+
const std::string& table_name) {
261+
VLOG(4) << "RequestNotifyHandler" << varname;
262+
if (varname == LEARNING_RATE_DECAY_MESSAGE) {
263+
PADDLE_ENFORCE_NE(
264+
lr_decay_block_id, -1,
265+
"when lr_decay_block_id = -1, there should be no RPC invoke.");
266+
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
267+
}
268+
return true;
269+
}
270+
254271
} // namespace distributed
255272
} // namespace operators
256273
} // namespace paddle

paddle/fluid/operators/distributed/request_handler_impl.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <time.h>
1818

1919
#include <functional>
20+
#include <memory>
2021
#include <string>
2122
#include <utility>
2223
#include <vector>
@@ -126,6 +127,22 @@ class RequestCheckpointHandler final : public RequestHandler {
126127
int checkpoint_notify_id;
127128
};
128129

130+
class RequestNotifyHandler final : public RequestHandler {
131+
public:
132+
explicit RequestNotifyHandler(bool sync_mode, int lr_decay_block_id)
133+
: RequestHandler(sync_mode) {
134+
this->lr_decay_block_id = lr_decay_block_id;
135+
}
136+
virtual ~RequestNotifyHandler() {}
137+
bool Handle(const std::string& varname, framework::Scope* scope,
138+
framework::Variable* var, framework::Variable** outvar,
139+
const int trainer_id, const std::string& out_var_name = "",
140+
const std::string& table_name = "") override;
141+
142+
private:
143+
int lr_decay_block_id;
144+
};
145+
129146
} // namespace distributed
130147
} // namespace operators
131148
} // namespace paddle

paddle/fluid/operators/distributed/rpc_client.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class RPCClient {
8080
const std::string& ep, const std::string& dir,
8181
int64_t time_out = FLAGS_rpc_deadline) = 0;
8282

83+
virtual VarHandlePtr AsyncDistributeNotify(
84+
const std::string& ep, const std::string& type,
85+
int64_t time_out = FLAGS_rpc_deadline) = 0;
86+
8387
virtual VarHandlePtr AsyncSendComplete(
8488
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
8589

paddle/fluid/operators/distributed/send_recv.proto.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ service SendRecvService {
2828
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
2929

3030
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
31+
rpc DistributeNotify(VariableMessage) returns (VoidMessage) {}
3132

3233
rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {}
3334
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <future> // NOLINT
16+
#include <ostream>
17+
18+
#include "paddle/fluid/framework/data_type.h"
19+
#include "paddle/fluid/framework/lod_tensor.h"
20+
#include "paddle/fluid/framework/op_registry.h"
21+
#include "paddle/fluid/operators/distributed/distributed.h"
22+
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
23+
#include "paddle/fluid/string/printf.h"
24+
25+
namespace paddle {
26+
namespace operators {
27+
28+
class DistributedNotifyOp : public framework::OperatorBase {
29+
public:
30+
DistributedNotifyOp(const std::string& type,
31+
const framework::VariableNameMap& inputs,
32+
const framework::VariableNameMap& outputs,
33+
const framework::AttributeMap& attrs)
34+
: OperatorBase(type, inputs, outputs, attrs) {}
35+
36+
void RunImpl(const framework::Scope& scope,
37+
const platform::Place& place) const override {
38+
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
39+
std::string type = Attr<std::string>("type");
40+
int trainer_id = Attr<int>("trainer_id");
41+
42+
distributed::RPCClient* rpc_client =
43+
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
44+
for (size_t i = 0; i < epmap.size(); i++) {
45+
rpc_client->AsyncDistributeNotify(epmap[i], type);
46+
VLOG(4) << "distribute notify sending : " << type << " to " << epmap[i];
47+
}
48+
PADDLE_ENFORCE_EQ(rpc_client->Wait(), true, "internal error in RPCClient");
49+
}
50+
};
51+
52+
// Declares the attributes and documentation of the distributed_notify op.
// The op has no inputs or outputs; it is configured purely through
// attributes.
class DistributedNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Marked `override` so a signature drift from the base class's Make() is
  // caught at compile time.
  void Make() override {
    AddAttr<std::vector<std::string>>("epmap",
                                      "(string vector, default 127.0.0.1:6164)"
                                      "Parameter Server endpoints in the order")
        .SetDefault({"127.0.0.1:6164"});
    // NOTE(review): the help text advertises a default of '' but no default
    // is set here — confirm whether a SetDefault("") is intended.
    AddAttr<std::string>("type",
                         "(string, default '') indicate the action type");
    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
    AddComment(R"DOC(
DistributeNotify operator

This operator will send a signal to listen_and_serve op at
the parameter server.
)DOC");
  }
};
70+
71+
// Shape inference functor for distributed_notify. It is a deliberate no-op:
// the context is accepted and left untouched.
class DistributedNotifyOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    // Nothing to infer.
  }
};
75+
76+
} // namespace operators
77+
} // namespace paddle
78+
79+
// Short alias for the operator namespace, used by the registration below.
namespace ops = paddle::operators;
80+
81+
// Registers the op under the name "distributed_notify", wiring together the
// operator class, its proto/attribute maker and its (no-op) shape inference.
// NOTE(review): EmptyGradOpMaker presumably means the op has no gradient op —
// confirm against the framework headers.
REGISTER_OPERATOR(distributed_notify, ops::DistributedNotifyOp,
82+
paddle::framework::EmptyGradOpMaker,
83+
ops::DistributedNotifyOpMaker,
84+
ops::DistributedNotifyOpShapeInference);

0 commit comments

Comments
 (0)