Skip to content

Commit 5f4d913

Browse files
committed
merge codes
1 parent ae19d2e commit 5f4d913

File tree

3 files changed

+10
-16
lines changed

3 files changed

+10
-16
lines changed

paddle/operators/detail/grpc_server.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
162162
}
163163

164164
// This URL explains why shutdown is complicate:
165-
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
166165
void AsyncGRPCServer::ShutDown() {
167166
server_->Shutdown();
168167
ShutdownQueue();
@@ -188,6 +187,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
188187
VLOG(4) << "create Requestget status:" << get->Status();
189188
}
190189

190+
// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
191191
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
192192
std::string cq_name,
193193
std::function<void()> TryToRegisterNewOne) {
@@ -202,7 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
202202
}
203203

204204
PADDLE_ENFORCE(tag);
205-
if (cq_name == "cq_get") WaitCond(2);
205+
// FIXME(typhoonzero): de-couple the barriers with recv_op
206+
if (cq_name == "cq_get") WaitCond(1);
206207
if (cq_name == "cq_send") WaitCond(0);
207208

208209
RequestBase* base = (RequestBase*)tag;

paddle/operators/detail/grpc_server.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
4242
void RunSyncUpdate();
4343

4444
// functions to sync server barrier status.
45-
void WaitStart();
46-
void WaitDone();
47-
void Start();
48-
void Done();
45+
void WaitCond(int cond);
46+
void SetCond(int cond);
4947
void WaitClientGet(int count);
5048

5149
void SetScope(framework::Scope *scope) { scope_ = scope; }

paddle/operators/recv_op.cc

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,14 @@ class RecvOp : public framework::OperatorBase {
105105
framework::ProgramDesc program(program_desc);
106106
framework::Executor executor(dev_place);
107107

108-
// rpc_service_->Reset();
109108
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
110109
bool exit_flag = false;
110+
int64_t barrier_size = param_count * fan_in;
111111
while (!exit_flag) {
112112
// Get from multiple trainers, we don't care about the order in which
113113
// the gradients arrives, just add suffix 0~n and merge the gradient.
114-
rpc_service_->SetCond(kCondStart);
115-
VLOG(3) << "================ start get from service ===========";
116-
for (size_t i = 0; i < param_count * fan_in; ++i) {
114+
rpc_service_->SetCond(0);
115+
for (size_t i = 0; i < barrier_size; ++i) {
117116
const detail::MessageWithName &v = rpc_service_->Get();
118117
auto grad_var_name = v.first;
119118
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
@@ -130,8 +129,6 @@ class RecvOp : public framework::OperatorBase {
130129
}
131130
VLOG(3) << "recved grad: " << grad_var_name
132131
<< " updating param: " << param_var_name;
133-
// Assume grad_var_name must appear in global scope.
134-
std::string grad_var_name_trainer;
135132
if (fan_in > 1) {
136133
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
137134
}
@@ -145,16 +142,14 @@ class RecvOp : public framework::OperatorBase {
145142
if (exit_flag) {
146143
break;
147144
}
148-
// rpc_service_->Reset();
149145
try {
150146
executor.Run(program, &recv_scope, 0, /*global_block*/
151147
false /*create_local_scope*/, false /*create_vars*/);
152148
} catch (std::exception &e) {
153149
LOG(ERROR) << "run sub program error " << e.what();
154150
}
155-
VLOG(3) << "================ run sub program end ===========";
156-
rpc_service_->SetCond(kCondDone);
157-
rpc_service_->WaitClientGet(param_count * fan_in);
151+
rpc_service_->SetCond(1);
152+
rpc_service_->WaitClientGet(barrier_size);
158153
grads_counter_.clear();
159154
} // while(true)
160155
}

0 commit comments

Comments
 (0)