Skip to content

Commit fc06222

Browse files
committed
fix async worker
1 parent 540b453 commit fc06222

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

paddle/fluid/operators/send_barrier_op.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class SendBarrierOp : public framework::OperatorBase {
3737
void RunImpl(const framework::Scope& scope,
3838
const platform::Place& place) const override {
3939
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
40+
bool sync_mode = Attr<bool>("sync_mode");
4041

4142
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
4243
auto& ctx = *pool.Get(place);
@@ -51,12 +52,13 @@ class SendBarrierOp : public framework::OperatorBase {
5152

5253
// need to wait before sending send_barrier message
5354
PADDLE_ENFORCE(rpc_client->Wait());
54-
55-
for (auto& ep : eps) {
56-
VLOG(3) << "send barrier, ep: " << ep;
57-
rpc_client->AsyncSendBatchBarrier(ep);
55+
if (sync_mode) {
56+
for (auto& ep : eps) {
57+
VLOG(3) << "send barrier, ep: " << ep;
58+
rpc_client->AsyncSendBatchBarrier(ep);
59+
}
60+
PADDLE_ENFORCE(rpc_client->Wait());
5861
}
59-
PADDLE_ENFORCE(rpc_client->Wait());
6062
}
6163
};
6264

@@ -77,6 +79,7 @@ the Parameter Server would know all variables have been sent.
7779
"(string vector, default 127.0.0.1:6164)"
7880
"Server endpoints to send variables to.")
7981
.SetDefault({"127.0.0.1:6164"});
82+
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
8083
}
8184
};
8285

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def net_conf(self):
4949
def test_transpiler(self):
5050
trainer = self.get_trainer()
5151
pserver, startup = self.get_pserver(self.current_pserver_ep)
52-
5352
self.assertEqual([op.type for op in trainer.global_block().ops],
5453
self.get_expect_trainer_ops())
5554

@@ -67,7 +66,7 @@ def test_transpiler(self):
6766
"fill_constant", "fill_constant", "uniform_random", "uniform_random"
6867
])
6968

70-
# the variable #fc_w will be split into two blocks
69+
# the variable #fc_w will be split into two blocks
7170
fc_w_var = startup.global_block().var("fc_w.block1")
7271
self.assertEqual(fc_w_var.shape, (500, 1000))
7372

@@ -86,8 +85,12 @@ def get_expect_trainer_ops(self):
8685
optimize_ops, params_grads = self.net_conf()
8786

8887
delete_ops(trainer.global_block(), optimize_ops)
89-
return [op.type for op in trainer.global_block().ops
90-
] + ["split_byref", "send", "concat"]
88+
ops = [op.type for op in trainer.global_block().ops] + [
89+
"split_byref", "send_vars", "send_barrier", "recv", "recv",
90+
"fetch_barrier", "concat"
91+
]
92+
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
93+
return ops
9194

9295
def get_trainer(self):
9396
return self._transpiler_instance().get_trainer_program()

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,10 @@ def transpile(self,
348348
type="send_barrier",
349349
inputs={},
350350
outputs={"RPCClient": rpc_client_var},
351-
attrs={"endpoints": pserver_endpoints})
351+
attrs={
352+
"endpoints": pserver_endpoints,
353+
"sync_mode": self.sync_mode
354+
})
352355

353356
# step 3.2: insert recv op to receive parameters from parameter server
354357
recv_vars = []

python/paddle/fluid/transpiler/ps_dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
class PSDispatcher(object):
1717
"""
18-
DistributedSpliter is the base class for dispatching vars
18+
PSDispatcher is the base class for dispatching vars
1919
into different pserver instance.
2020
You need to implement the `dispatch` interface.
2121
"""

0 commit comments

Comments
 (0)