Skip to content

Commit d13ce35

Browse files
typhoonzerogongweibao
authored andcommitted
Feature/send recv can now retry (#9027)
1 parent 14fe40a commit d13ce35

File tree

8 files changed

+83
-25
lines changed

8 files changed

+83
-25
lines changed

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
9797
return true;
9898
}
9999

100-
bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
100+
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
101101
const auto ch = GetChannel(ep);
102102

103103
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
@@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
108108
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
109109
rpc->Finish(&s->reply_, &s->status_, (void*)s);
110110
req_count_++;
111+
}
111112

112-
return true;
113+
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
114+
const auto ch = GetChannel(ep);
115+
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
116+
s->Prepare(time_out);
117+
118+
sendrecv::VariableMessage req;
119+
req.set_varname(FETCH_BARRIER_MESSAGE);
120+
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
121+
rpc->Finish(&s->reply_, &s->status_, (void*)s);
122+
req_count_++;
113123
}
114124

115125
bool RPCClient::Wait() {
@@ -154,7 +164,7 @@ bool RPCClient::Proceed() {
154164
PADDLE_ENFORCE(tag);
155165

156166
// TODO(gongwb): add more retries.
157-
ClientBase* c = static_cast<ClientBase*>(tag);
167+
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
158168
if (!c->status_.ok()) {
159169
LOG(ERROR) << "proc param error:" << c->var_h_.String()
160170
<< " grpc error:" << c->status_.error_message();
@@ -174,6 +184,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
174184
}
175185

176186
grpc::ChannelArguments args;
187+
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
188+
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
177189
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
178190
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
179191

paddle/fluid/operators/detail/grpc_client.h

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,14 @@ struct VarHandle {
5252
void ProcGetResponse(const VarHandle& var_h,
5353
const sendrecv::VariableMessage& msg);
5454

55-
class ClientBase {
55+
class BaseProcessor {
5656
public:
57-
explicit ClientBase(std::shared_ptr<grpc::Channel> ch) {
57+
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
5858
stub_ = sendrecv::SendRecvService::NewStub(ch);
5959
context_ = NULL;
6060
}
6161

62-
virtual ~ClientBase() {}
62+
virtual ~BaseProcessor() {}
6363

6464
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
6565
context_.reset(new grpc::ClientContext());
@@ -91,9 +91,10 @@ class ClientBase {
9191
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
9292
RequestSendCallBack;
9393

94-
class SendProcessor : public ClientBase {
94+
class SendProcessor : public BaseProcessor {
9595
public:
96-
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
96+
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
97+
: BaseProcessor(ch) {}
9798

9899
virtual ~SendProcessor() {}
99100

@@ -110,9 +111,10 @@ class SendProcessor : public ClientBase {
110111
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
111112
RequestGetCallBack;
112113

113-
class GetProcessor : public ClientBase {
114+
class GetProcessor : public BaseProcessor {
114115
public:
115-
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
116+
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
117+
: BaseProcessor(ch) {}
116118

117119
virtual ~GetProcessor() {}
118120

@@ -126,17 +128,28 @@ class GetProcessor : public ClientBase {
126128
RequestGetCallBack response_call_back_ = ProcGetResponse;
127129
};
128130

129-
class BatchBarrierProcessor : public ClientBase {
131+
class BatchBarrierProcessor : public BaseProcessor {
130132
public:
131133
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
132-
: ClientBase(ch) {}
134+
: BaseProcessor(ch) {}
133135

134136
virtual ~BatchBarrierProcessor() {}
135137

136138
virtual void Process() {}
137139
sendrecv::VoidMessage reply_;
138140
};
139141

142+
class FetchBarrierProcessor : public BaseProcessor {
143+
public:
144+
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
145+
: BaseProcessor(ch) {}
146+
147+
virtual ~FetchBarrierProcessor() {}
148+
149+
virtual void Process() {}
150+
sendrecv::VariableMessage reply_;
151+
};
152+
140153
class RPCClient {
141154
public:
142155
bool AsyncSendVariable(const std::string& ep,
@@ -151,7 +164,10 @@ class RPCClient {
151164
const std::string& var_name,
152165
int64_t time_out = 600 * 1000);
153166

154-
bool AsyncSendBatchBarrier(const std::string& ep,
167+
void AsyncSendBatchBarrier(const std::string& ep,
168+
int64_t time_out = 600 * 1000);
169+
170+
void AsyncSendFetchBarrier(const std::string& ep,
155171
int64_t time_out = 600 * 1000);
156172

157173
bool Wait();

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class RequestGet final : public RequestBase {
8484
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
8585
grpc::ServerCompletionQueue* cq, framework::Scope* scope,
8686
const platform::DeviceContext* dev_ctx,
87-
SimpleBlockQueue<char>* queue)
87+
SimpleBlockQueue<MessageWithName>* queue)
8888
: RequestBase(service, cq),
8989
responder_(&ctx_),
9090
scope_(scope),
@@ -101,11 +101,16 @@ class RequestGet final : public RequestBase {
101101
// proc request.
102102
std::string var_name = request_.varname();
103103
auto* var = scope_->FindVar(var_name);
104-
SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
104+
if (var_name != FETCH_BARRIER_MESSAGE) {
105+
SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
106+
}
105107
// TODO(gongwb): check var's info.
106108
responder_.Finish(reply_, grpc::Status::OK, this);
107109
status_ = FINISH;
108-
queue_->Push('c');
110+
MessageWithName msg_with_name =
111+
// request name reply
112+
std::make_pair(var_name, std::move(reply_));
113+
queue_->Push(msg_with_name);
109114
}
110115

111116
protected:
@@ -114,12 +119,16 @@ class RequestGet final : public RequestBase {
114119
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
115120
framework::Scope* scope_;
116121
const platform::DeviceContext* dev_ctx_;
117-
SimpleBlockQueue<char>* queue_;
122+
SimpleBlockQueue<MessageWithName>* queue_;
118123
};
119124

120125
void AsyncGRPCServer::WaitClientGet(int count) {
121-
for (int i = 0; i < count; ++i) {
122-
var_get_queue_.Pop();
126+
int fetch_barriers = 0;
127+
while (fetch_barriers < count) {
128+
auto msg = var_get_queue_.Pop();
129+
if (msg.first == FETCH_BARRIER_MESSAGE) {
130+
fetch_barriers++;
131+
}
123132
}
124133
}
125134

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
7777
const platform::DeviceContext *dev_ctx_;
7878
// received variable from RPC, operators fetch variable from this queue.
7979
SimpleBlockQueue<MessageWithName> var_recv_queue_;
80-
SimpleBlockQueue<char> var_get_queue_;
80+
SimpleBlockQueue<MessageWithName> var_get_queue_;
8181

8282
// condition of the sub program
8383
std::mutex barrier_mutex_;

paddle/fluid/operators/detail/sendrecvop_utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ namespace detail {
3232

3333
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
3434
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
35+
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
3536

3637
typedef void (*DestroyCallback)(void*);
3738

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase {
128128
}
129129
}
130130
if (exit_flag) {
131-
rpc_service_->ShutDown();
132131
rpc_service_->SetCond(1);
132+
rpc_service_->ShutDown();
133133
break;
134134
}
135135
try {
@@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase {
148148
}
149149
rpc_service_->SetCond(1);
150150
// FIXME(typhoonzero): use another condition to sync wait clients get.
151-
rpc_service_->WaitClientGet(ins.size());
151+
rpc_service_->WaitClientGet(fan_in);
152152
sparse_vars.clear();
153153
} // while(true)
154154
}

paddle/fluid/operators/send_op.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase {
8888
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
8989
}
9090
PADDLE_ENFORCE(rpc_client->Wait());
91+
// tell pservers that current trainer have called fetch
92+
for (auto& ep : endpoints) {
93+
VLOG(3) << "send fetch barrier, ep: " << ep;
94+
rpc_client->AsyncSendFetchBarrier(ep);
95+
}
96+
PADDLE_ENFORCE(rpc_client->Wait());
9197
}
9298
}
9399
};

python/paddle/fluid/distribute_transpiler.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ def transpile(self,
250250
def get_trainer_program(self):
251251
# remove optimize ops and add a send op to main_program
252252
self.program.global_block().delete_ops(self.optimize_ops)
253+
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
254+
self.program.__str__()
253255
return self.program
254256

255257
def get_pserver_program(self, endpoint):
@@ -309,7 +311,8 @@ def get_pserver_program(self, endpoint):
309311
for _, opt_op in enumerate(opt_op_on_pserver):
310312
if ufind.is_connected(op, opt_op):
311313
if self._is_opt_op(op):
312-
self._append_pserver_ops(optimize_block, op, endpoint)
314+
self._append_pserver_ops(optimize_block, op, endpoint,
315+
default_main_program())
313316
else:
314317
self._append_pserver_non_opt_ops(optimize_block, op)
315318
break
@@ -520,7 +523,8 @@ def _orig_varname(self, varname):
520523
orig_var_name = varname[:suff_idx]
521524
return orig_var_name
522525

523-
def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
526+
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
527+
origin_program):
524528
program = optimize_block.program
525529
pserver_block = program.global_block()
526530
new_inputs = dict()
@@ -576,7 +580,17 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
576580
elif key == "LearningRate":
577581
# leraning rate variable has already be created by non-optimize op,
578582
# don't create it once again.
579-
new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
583+
lr_varname = opt_op.input(key)[0]
584+
if pserver_block.vars.has_key(lr_varname):
585+
new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
586+
else:
587+
origin_var = origin_program.global_block().vars[lr_varname]
588+
tmpvar = pserver_block.create_var(
589+
name=origin_var.name,
590+
persistable=origin_var.persistable,
591+
dtype=origin_var.dtype,
592+
shape=origin_var.shape)
593+
new_inputs[key] = tmpvar
580594

581595
for key in opt_op.input_names:
582596
new_shape = None

0 commit comments

Comments
 (0)