Skip to content

Commit efd5a84

Browse files
committed
update executor interface
1 parent 800702c commit efd5a84

File tree

6 files changed

+39
-13
lines changed

6 files changed

+39
-13
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
4545

4646
Executor::Executor(const platform::Place& place) : place_(place) {}
4747

48+
void Executor::Close() {
4849
#ifdef PADDLE_WITH_DISTRIBUTE
49-
void Executor::Complete() {
5050
::paddle::operators::distributed::RPCClient::GetInstance<
5151
::paddle::operators::distributed::GRPCClient>()
5252
->SendComplete();
53-
}
5453
#endif
54+
}
5555

5656
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
5757
if (var_type == proto::VarType::LOD_TENSOR) {

paddle/fluid/framework/executor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Executor {
4848
/*
4949
* Send a signal to the pserver to mark the current trainer as completed.
5050
*/
51-
void Complete();
51+
void Close();
5252

5353
#endif
5454

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,15 @@ void GRPCClient::InitEventLoop() {
3636
}
3737

3838
void GRPCClient::SendComplete() {
39-
for (auto& it : channels_) {
40-
VLOG(3) << "send complete message to " << it.first;
41-
this->AsyncSendComplete(it.first);
39+
std::unique_lock<std::mutex> lk(completed_mutex_);
40+
if (!completed_) {
41+
for (auto& it : channels_) {
42+
VLOG(3) << "send complete message to " << it.first;
43+
this->AsyncSendComplete(it.first);
44+
}
45+
PADDLE_ENFORCE(this->Wait(), "internal grpc error");
46+
completed_ = true;
4247
}
43-
this->Wait();
4448
}
4549

4650
GRPCClient::~GRPCClient() {

paddle/fluid/operators/distributed/grpc_client.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
188188

189189
class GRPCClient : public RPCClient {
190190
public:
191-
GRPCClient() : ok_(true) {}
191+
GRPCClient() : ok_(true), completed_(false) {}
192192
virtual ~GRPCClient();
193193

194194
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
@@ -247,6 +247,10 @@ class GRPCClient : public RPCClient {
247247
// mutex for GetChannel thread safety
248248
std::mutex chan_mutex_;
249249
DISABLE_COPY_AND_ASSIGN(GRPCClient);
250+
251+
// mutex for sending complete message only once
252+
std::mutex completed_mutex_;
253+
bool completed_;
250254
};
251255

252256
} // namespace distributed

paddle/fluid/pybind/pybind.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -502,9 +502,7 @@ All parameter, weight, gradient are variables in Paddle.
502502

503503
py::class_<framework::Executor>(m, "Executor")
504504
.def(py::init<const platform::Place &>())
505-
#ifdef PADDLE_WITH_DISTRIBUTE
506-
.def("complete", &Executor::Complete)
507-
#endif
505+
.def("close", &Executor::Close)
508506
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
509507
int block_id, bool create_local_scope, bool create_vars) {
510508
pybind11::gil_scoped_release release;

python/paddle/fluid/executor.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def __init__(self, place):
247247
p.set_place(place)
248248
self.executor = core.Executor(p)
249249
self.program_caches = dict()
250+
self._closed = False
250251

251252
def as_lodtensor(self, data):
252253
"""
@@ -348,8 +349,23 @@ def _fetch_data(self, fetch_list, fetch_var_name, scope):
348349
]
349350
return outs
350351

351-
def complete(self):
352-
self.executor.complete()
352+
def close(self):
353+
"""
354+
Close this executor.
355+
356+
You can no longer use this executor after calling this method.
357+
For the distributed training, this method would free the resource on PServers related to
358+
the current Trainer.
359+
360+
Example:
361+
>>> cpu = core.CPUPlace()
362+
>>> exe = Executor(cpu)
363+
>>> ...
364+
>>> exe.close()
365+
"""
366+
if not self._closed:
367+
self.executor.close()
368+
self._closed = True
353369

354370
def run(self,
355371
program=None,
@@ -402,6 +418,10 @@ def run(self,
402418
>>> feed={'X': x},
403419
>>> fetch_list=[loss.name])
404420
"""
421+
422+
if self._closed:
423+
raise RuntimeError("Attempted to use a closed Executor")
424+
405425
if feed is None:
406426
feed = {}
407427
if not isinstance(feed, dict):

0 commit comments

Comments
 (0)