Skip to content

Commit a0ced3d

Browse files
committed
async update can run
1 parent 34f2818 commit a0ced3d

File tree

5 files changed

+33
-24
lines changed

5 files changed

+33
-24
lines changed

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
315315
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
316316

317317
PADDLE_ENFORCE(tag);
318-
// FIXME(typhoonzero): de-couple the barriers with recv_op
319-
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
320-
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
318+
if (sync_mode_) {
319+
// FIXME(typhoonzero): de-couple the barriers with recv_op
320+
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
321+
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
322+
}
321323

322324
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
323325
// reference:
@@ -334,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
334336

335337
switch (base->Status()) {
336338
case PROCESS: {
337-
VLOG(4) << cq_name << " status:" << base->Status();
339+
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
338340
TryToRegisterNewOne();
339341
base->Process();
340342
break;
341343
}
342344
case FINISH: {
343-
VLOG(4) << cq_name << " status:" << base->Status();
345+
VLOG(4) << cq_name << " FINISH status:" << base->Status();
344346
delete base;
345347
break;
346348
}

paddle/fluid/operators/detail/variable_response.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class VariableResponse {
6161
// other: number of error field.
6262
int Parse(const ::grpc::ByteBuffer& byte_buffer);
6363

64-
const framework::Scope& GetLocalScope() const { return *local_scope_; }
64+
framework::Scope& GetLocalScope() const { return *local_scope_; }
6565

6666
inline std::string Varname() { return meta_.varname(); }
6767
inline std::string OutVarname() { return meta_.out_varname(); }

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,15 @@ static void split(const std::string &str, char sep,
4848
static void AsyncExecuteBlock(framework::Executor *executor,
4949
framework::ExecutorPrepareContext *prepared,
5050
framework::Scope *scope) {
51-
framework::Async([&executor, &prepared, &scope]() {
51+
std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
5252
try {
5353
executor->RunPreparedContext(prepared, scope, false, false);
5454
} catch (std::exception &e) {
5555
LOG(ERROR) << "run sub program error " << e.what();
5656
}
5757
});
58+
// TODO(qiao) maybe we can remove this
59+
future.wait();
5860
}
5961

6062
static void ParallelExecuteBlocks(
@@ -203,14 +205,16 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
203205
framework::ProgramDesc *program,
204206
framework::Scope *recv_scope,
205207
framework::BlockDesc *prefetch_block) const {
208+
VLOG(3) << "RunAsyncLoop in";
206209
// grad name to block id
207210
std::unordered_map<std::string, int32_t> grad_to_id;
208211
std::unordered_map<int32_t, std::string> id_to_grad;
209212

210213
auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id");
211214
for (auto &grad_and_id : grad_to_id_str) {
212215
std::vector<std::string> pieces;
213-
split(grad_and_id, ' ', &pieces);
216+
split(grad_and_id, ':', &pieces);
217+
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
214218
PADDLE_ENFORCE_EQ(pieces.size(), 2);
215219
PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
216220
int block_id = std::stoi(pieces[1]);
@@ -223,21 +227,17 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
223227

224228
std::vector<int> block_list;
225229
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
226-
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
227-
block_list.push_back(blkid);
228-
}
230+
block_list.push_back(blkid);
229231
}
230-
PADDLE_ENFORCE_EQ(grad_to_id_str.size(), block_list.size(),
231-
"grad num should be equal to optimize block num");
232232
auto optimize_prepared = executor->Prepare(*program, block_list);
233-
234233
std::unordered_map<std::string,
235234
std::shared_ptr<framework::ExecutorPrepareContext>>
236235
grad_to_prepared;
237236
for (size_t i = 0; i < block_list.size(); ++i) {
238237
grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
239238
}
240239

240+
VLOG(3) << "RunAsyncLoop into while";
241241
bool exit_flag = false;
242242
while (!exit_flag) {
243243
const detail::ReceivedMessage v = rpc_service_->Get();
@@ -254,7 +254,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
254254
PADDLE_THROW("Can not find server side var");
255255
}
256256
AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(),
257-
recv_scope);
257+
&(v.second->GetLocalScope()));
258258
// TODO(qiao): explain why
259259
if (var->IsType<framework::SelectedRows>()) {
260260
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();

paddle/fluid/operators/send_op.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase {
4141
std::vector<std::string> endpoints =
4242
Attr<std::vector<std::string>>("endpoints");
4343

44+
bool sync_mode = Attr<bool>("sync_mode");
45+
4446
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
4547
auto& ctx = *pool.Get(place);
4648

@@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase {
6466
}
6567
PADDLE_ENFORCE(rpc_client->Wait());
6668

67-
for (auto& ep : endpoints) {
68-
VLOG(3) << "batch barrier, ep: " << ep;
69-
rpc_client->AsyncSendBatchBarrier(ep);
69+
if (sync_mode) {
70+
for (auto& ep : endpoints) {
71+
VLOG(3) << "batch barrier, ep: " << ep;
72+
rpc_client->AsyncSendBatchBarrier(ep);
73+
}
74+
PADDLE_ENFORCE(rpc_client->Wait());
7075
}
71-
PADDLE_ENFORCE(rpc_client->Wait());
7276

7377
if (outs.size() > 0) {
7478
for (size_t i = 0; i < outs.size(); i++) {
@@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server.
112116
"Server endpoints in the order of input "
113117
"variables for mapping")
114118
.SetDefault({});
119+
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
115120
}
116121
};
117122

python/paddle/fluid/distribute_transpiler.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,11 @@ def transpile(self,
297297
inputs={"X": send_inputs},
298298
outputs={"Out": send_outputs,
299299
"RPCClient": rpc_client_var},
300-
attrs={"endpoints": pserver_endpoints,
301-
"epmap": eplist})
300+
attrs={
301+
"endpoints": pserver_endpoints,
302+
"epmap": eplist,
303+
"sync_mode": self.sync_mode
304+
})
302305
# step4: Concat the parameters splits together after recv.
303306
for varname, splited_var in param_var_mapping.iteritems():
304307
if len(splited_var) <= 1:
@@ -404,8 +407,8 @@ def get_pserver_program(self, endpoint):
404407
for op in self.optimize_ops:
405408
if op.type == "scale":
406409
for in_name in op.input_arg_names:
407-
if in_name.startswith("beta1_pow_acc") or\
408-
in_name.startswith("beta2_pow_acc"):
410+
if in_name.startswith("beta1_pow_acc") or \
411+
in_name.startswith("beta2_pow_acc"):
409412
global_ops.append(op)
410413

411414
def __append_optimize_op__(op, block, grad_to_block_id):
@@ -434,7 +437,6 @@ def __append_optimize_op__(op, block, grad_to_block_id):
434437
__append_optimize_op__(op, per_opt_block, grad_to_block_id)
435438

436439
# append global ops
437-
opt_state_block = None
438440
if global_ops:
439441
opt_state_block = pserver_program.create_block(
440442
pserver_program.num_blocks - 1)

0 commit comments

Comments (0)