
Commit 580340e

velconia authored and typhoonzero committed
Shut down pserver gracefully when SIGINT or SIGTERM is sent (#10984)
* Implement StopAll in ListenAndServOp; make pserver receive SIGINT and SIGTERM from outside; add unit tests for listen_and_serv_op in Python.
* Add a blocking-queue set to record the queues; wake all blocking queues on exit so the process can exit gracefully.
* Remove comment lines from blocking_queue.h; implement SignalHandler and move all global vars and funcs into it.
* Make the code follow the style check; move SignalHandler out of the unnamed namespace.
* Make yapf happy.
* Call Stop() in the destructor to release the resources allocated by ListenAndServOp; change the exit status to EXIT_SUCCESS after handling a signal from outside; remove the misuse of REMOVE_ITEM in the unit tests.
* Use DISABLE_COPY_AND_ASSIGN; use the pragma once macro only.
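The mechanism this commit adds fits on one screen: a process-wide exit flag plus a registry of blocking queues; the signal handler flips the flag, pushes a sentinel message into every registered queue so blocked consumer threads wake up, and then exits. The sketch below is a minimal, self-contained illustration of that pattern, not PaddlePaddle code; all names (`BlockingQueue`, `StopAndExit`, `kTerminateMessage`) are illustrative stand-ins for the classes in the diff.

```cpp
// Minimal sketch of the shutdown pattern in this commit (illustrative names,
// not PaddlePaddle code): a signal handler sets a global exit flag, pushes a
// sentinel into every registered blocking queue so threads stuck in Pop()
// wake up, and then exits the process.
#include <csignal>
#include <cstdlib>
#include <condition_variable>
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

const char kTerminateMessage[] = "TERMINATE";

class BlockingQueue {
 public:
  void Push(const std::string &msg) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(msg);
    }
    cv_.notify_one();
  }
  std::string Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    std::string msg = queue_.front();
    queue_.pop_front();
    return msg;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<std::string> queue_;
};

volatile std::sig_atomic_t program_exit_flag = 0;
std::vector<std::shared_ptr<BlockingQueue>> registered_queues;

void StopAndExit(int signal_num) {
  // NOTE: locking a mutex is not async-signal-safe in general; the commit
  // accepts the same caveat. Wake every consumer, then terminate.
  program_exit_flag = 1;
  for (auto &q : registered_queues) q->Push(kTerminateMessage);
  std::exit(EXIT_SUCCESS);  // exit status EXIT_SUCCESS, as in the commit
}

int main() {
  auto queue = std::make_shared<BlockingQueue>();
  registered_queues.push_back(queue);  // analogous to RegisterBlockingQueue

  std::signal(SIGINT, StopAndExit);   // ctrl+C
  std::signal(SIGTERM, StopAndExit);  // kill <pid>

  std::thread worker([queue] {
    while (!program_exit_flag) {
      std::string msg = queue->Pop();  // blocks until Push() or the sentinel
      if (msg == kTerminateMessage) break;
      std::cout << "handled " << msg << std::endl;
    }
  });
  // The handler calls exit(), so after SIGINT/SIGTERM the process ends
  // inside StopAndExit and this join() never returns.
  worker.join();
  return 0;
}
```

Build with `g++ -std=c++11 -pthread`; run it and press ctrl+C, and the process exits with status 0 instead of hanging in Pop().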
1 parent 3d6934e commit 580340e

File tree

4 files changed (+185 −18 lines)


paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 51 additions & 17 deletions
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <stdio.h>  // for removing the port file
+#include <csignal>
+#include <cstdlib>
 #include <fstream>
-#include <ostream>
 #include <thread>  // NOLINT
 #include <vector>
 
@@ -28,7 +29,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   service->RunSyncUpdate();
   VLOG(4) << "RunServer thread end";
 }
-
 static void split(const std::string &str, char sep,
                   std::vector<std::string> *pieces) {
   pieces->clear();
@@ -59,7 +59,7 @@ static void ParallelExecuteBlocks(
         int run_block = idx;  // thread local
         try {
           executor->RunPreparedContext(prepared[run_block].get(), scope);
-        } catch (std::exception &e) {
+        } catch (const std::exception &e) {
           LOG(ERROR) << "run sub program error " << e.what();
         }
       }));
@@ -75,8 +75,11 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
                                  const framework::AttributeMap &attrs)
     : OperatorBase(type, inputs, outputs, attrs) {}
 
+ListenAndServOp::~ListenAndServOp() { Stop(); }
+
 void ListenAndServOp::Stop() {
   rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
+  rpc_service_->ShutDown();
   server_thread_->join();
   auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
   remove(file_path.c_str());
@@ -122,7 +125,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
   // Record received sparse variables, so that
   // we could reset those after execute optimize program
   std::vector<framework::Variable *> sparse_vars;
-  while (!exit_flag) {
+  while (!exit_flag && !SignalHandler::IsProgramExit()) {
     // Get from multiple trainers, we don't care about the order in which
     // the gradients arrives, just add suffix 0~n and merge the gradient.
     rpc_service_->SetCond(0);
@@ -187,7 +190,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
     // mini-batch.
     // TODO(Yancey1989): move the reset action into an operator, we couldn't
     // have any hide logic in the operator.
-    for (auto &var : sparse_vars) {
+    for (framework::Variable *var : sparse_vars) {
       var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
     }
 
@@ -204,8 +207,12 @@ static void AsyncUpdateThread(
     framework::Executor *executor,
     framework::ExecutorPrepareContext *prepared) {
   VLOG(3) << "update thread for " << var_name << " started";
-  while (!exit_flag) {
+  while (!exit_flag && !SignalHandler::IsProgramExit()) {
     const detail::ReceivedMessage v = queue->Pop();
+    if (SignalHandler::IsProgramExit()) {
+      VLOG(3) << "update thread for " << var_name << " exit";
+      break;
+    }
     auto recv_var_name = v.first;
     VLOG(4) << "async update " << recv_var_name;
     auto var = v.second->GetVar();
@@ -217,7 +224,7 @@ static void AsyncUpdateThread(
       try {
         executor->RunPreparedContext(prepared,
                                      v.second->GetMutableLocalScope());
-      } catch (std::exception &e) {
+      } catch (const std::exception &e) {
        LOG(ERROR) << "run sub program error " << e.what();
       }
     });
@@ -236,15 +243,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
 
   auto grad_to_block_id_str =
       Attr<std::vector<std::string>>("grad_to_block_id");
-  for (auto &grad_and_id : grad_to_block_id_str) {
+  for (const auto &grad_and_id : grad_to_block_id_str) {
     std::vector<std::string> pieces;
     split(grad_and_id, ':', &pieces);
     VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
     PADDLE_ENFORCE_EQ(pieces.size(), 2);
     PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
     int block_id = std::stoi(pieces[1]);
     grad_to_block_id[pieces[0]] = block_id;
-    grad_to_queue[pieces[0]] = std::make_shared<detail::ReceivedQueue>();
+    std::shared_ptr<detail::ReceivedQueue> queue =
+        std::make_shared<detail::ReceivedQueue>();
+    grad_to_queue[pieces[0]] = queue;
+    // record blocking queue in SignalHandler
+    SignalHandler::RegisterBlockingQueue(queue);
     id_to_grad[block_id] = pieces[0];
   }
   size_t num_blocks = program->Size();
@@ -276,9 +287,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
           executor, grad_to_prepared_ctx[grad_name].get());
     }));
   }
-
   VLOG(3) << "RunAsyncLoop into while";
-  while (!exit_flag) {
+  while (!exit_flag && !SignalHandler::IsProgramExit()) {
     const detail::ReceivedMessage v = rpc_service_->Get();
     auto recv_var_name = v.first;
     if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
@@ -333,6 +343,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   VLOG(3) << "wait server thread to become ready...";
   rpc_service_->WaitServerReady();
 
+  // register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
+  signal(SIGINT, SignalHandler::StopAndExit);
+  signal(SIGTERM, SignalHandler::StopAndExit);
+
   // Write to a file of server selected port for python use.
   std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port",
                                           static_cast<int>(::getpid()));
@@ -348,12 +362,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
  void Make() {
    AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
-    AddComment(R"DOC(
-ListenAndServ operator
-
-This operator will start a RPC server which can receive variables
-from send_op and send back variables to recv_op.
-)DOC");
+    AddComment(R"DOC(" + "ListenAndServ operator" + "\n" + "This operator" +
+" will start a RPC server which can receive variables from send_op and send" +
+"back variables to recv_op.)DOC");
     AddAttr<std::string>("endpoint",
                          "(string, default 127.0.0.1:6164)"
                          "IP address to listen on.")
@@ -374,6 +385,29 @@ from send_op and send back variables to recv_op.
   }
 };
 
+bool SignalHandler::program_exit_flag_ = false;
+
+SignalHandler::BlockingQueueSet SignalHandler::blocking_queue_set_{};
+
+void SignalHandler::StopAndExit(int signal_num) {
+  VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
+
+  program_exit_flag_ = true;
+
+  // awake all blocking queues
+  for (BlockingQueueSet::iterator iter = blocking_queue_set_.begin();
+       iter != blocking_queue_set_.end(); iter++) {
+    iter->get()->Push(
+        std::make_pair(std::string(LISTEN_TERMINATE_MESSAGE), nullptr));
+  }
+
+  exit(EXIT_SUCCESS);
+}
+
+void SignalHandler::RegisterBlockingQueue(BlockingQueue &queue) {
+  blocking_queue_set_.insert(queue);
+}
+
 }  // namespace operators
 }  // namespace paddle
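One detail worth calling out from the hunk above: `~ListenAndServOp()` now delegates to `Stop()`, so destroying the op always joins the server thread and removes the port file even when nobody calls `Stop()` explicitly. A hedged sketch of that RAII shape, with an illustrative `Server` class rather than the Paddle one:

```cpp
// Sketch (illustrative names): destructor-driven shutdown, so the worker
// thread is always joined when the owning object goes out of scope.
#include <atomic>
#include <iostream>
#include <thread>

class Server {
 public:
  Server() : stop_(false), thread_([this] { Run(); }) {}
  ~Server() { Stop(); }  // RAII: destruction implies shutdown

  void Stop() {
    if (thread_.joinable()) {
      stop_ = true;    // signal the loop to finish
      thread_.join();  // wait for the worker to exit
      std::cout << "server stopped" << std::endl;
    }
  }

 private:
  void Run() {
    while (!stop_) std::this_thread::yield();  // stand-in for serving RPCs
  }
  std::atomic<bool> stop_;  // declared before thread_, so it is ready first
  std::thread thread_;
};

int main() {
  Server s;  // leaving scope triggers ~Server() -> Stop() -> join()
  return 0;
}
```

`Stop()` checks `joinable()` first, so calling it twice (explicitly and then from the destructor) is harmless, mirroring why the commit can safely call `Stop()` from `~ListenAndServOp()`.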

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 23 additions & 1 deletion
@@ -16,7 +16,7 @@ limitations under the License. */
 
 #include <stdint.h>
 #include <atomic>
-#include <ostream>
+#include <set>
 #include <string>
 
 #include "paddle/fluid/framework/executor.h"
@@ -40,6 +40,8 @@ class ListenAndServOp : public framework::OperatorBase {
                   const framework::VariableNameMap& outputs,
                   const framework::AttributeMap& attrs);
 
+  virtual ~ListenAndServOp();
+
   void RunSyncLoop(framework::Executor* executor,
                    framework::ProgramDesc* program,
                    framework::Scope* recv_scope,
@@ -68,5 +70,25 @@
   static std::atomic_int selected_port_;
 };
 
+class SignalHandler {
+ public:
+  typedef std::shared_ptr<detail::ReceivedQueue> BlockingQueue;
+  typedef std::unordered_set<BlockingQueue> BlockingQueueSet;
+
+ public:
+  static void StopAndExit(int signal_num);
+
+  static void RegisterBlockingQueue(BlockingQueue&);
+
+  static inline bool IsProgramExit() { return program_exit_flag_; }
+
+ private:
+  static bool program_exit_flag_;
+
+  static BlockingQueueSet blocking_queue_set_;
+
+  DISABLE_COPY_AND_ASSIGN(SignalHandler);
+};
+
 }  // namespace operators
 }  // namespace paddle
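A note on the `BlockingQueueSet` typedef above: `std::unordered_set<std::shared_ptr<T>>` hashes and compares by the stored raw pointer, so the set deduplicates queue *objects* by identity, which is exactly what a registry of live queues wants. A small self-contained check (hypothetical example, not Paddle code):

```cpp
#include <cassert>
#include <memory>
#include <unordered_set>

int main() {
  std::unordered_set<std::shared_ptr<int>> registry;
  auto q = std::make_shared<int>(42);
  registry.insert(q);
  registry.insert(q);  // same object: insert is a no-op
  registry.insert(std::make_shared<int>(42));  // distinct object, same value
  assert(registry.size() == 2);  // identity, not value, decides membership
  return 0;
}
```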

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -48,3 +48,5 @@ foreach(TEST_OP ${TEST_OPS})
 endforeach(TEST_OP)
 py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
 py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
+# tests that need to be done in fixed timeout
+set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+import os
+import signal
+import subprocess
+import time
+import unittest
+from multiprocessing import Process
+from op_test import OpTest
+
+
+def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
+    x = fluid.layers.data(name='x', shape=[1], dtype='float32')
+    y_predict = fluid.layers.fc(input=x, size=1, act=None)
+    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+
+    # loss function
+    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+    avg_cost = fluid.layers.mean(cost)
+
+    # optimizer
+    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+    sgd_optimizer.minimize(avg_cost)
+
+    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    port = os.getenv("PADDLE_INIT_PORT", port)
+    pserver_ips = os.getenv("PADDLE_INIT_PSERVERS", ip)  # ip,ip...
+    eplist = []
+    for ip in pserver_ips.split(","):
+        eplist.append(':'.join([ip, port]))
+    pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
+    trainers = int(os.getenv("TRAINERS", trainer_count))
+    current_endpoint = os.getenv("POD_IP", ip) + ":" + port
+    trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID", trainer_id))
+    t = fluid.DistributeTranspiler()
+    t.transpile(
+        trainer_id,
+        pservers=pserver_endpoints,
+        trainers=trainers,
+        sync_mode=sync_mode)
+    pserver_prog = t.get_pserver_program(current_endpoint)
+    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
+    exe.run(pserver_startup)
+    exe.run(pserver_prog)
+
+
+class TestListenAndServOp(OpTest):
+    def setUp(self):
+        self.sleep_time = 5
+        self.ip = "127.0.0.1"
+        self.port = "6173"
+        self.trainer_count = 1
+        self.trainer_id = 1
+
+    def _raise_signal(self, parent_pid, raised_signal):
+        time.sleep(self.sleep_time)
+        ps_command = subprocess.Popen(
+            "ps -o pid --ppid %d --noheaders" % parent_pid,
+            shell=True,
+            stdout=subprocess.PIPE)
+        ps_output = ps_command.stdout.read()
+        retcode = ps_command.wait()
+        assert retcode == 0, "ps command returned %d" % retcode
+
+        for pid_str in ps_output.split("\n")[:-1]:
+            try:
+                os.kill(int(pid_str), raised_signal)
+            except Exception:
+                continue
+
+    def _start_pserver(self, use_cuda, sync_mode):
+        p = Process(
+            target=run_pserver,
+            args=(use_cuda, sync_mode, self.ip, self.port, self.trainer_count,
+                  self.trainer_id))
+        p.start()
+
+    def test_handle_signal_in_serv_op(self):
+        # run pserver on CPU in sync mode
+        self._start_pserver(False, True)
+
+        # raise SIGINT to pserver
+        self._raise_signal(os.getpid(), signal.SIGINT)
+
+        # run pserver on CPU in async mode
+        self._start_pserver(False, False)
+
+        # raise SIGTERM to pserver
+        self._raise_signal(os.getpid(), signal.SIGTERM)
+
+
+if __name__ == '__main__':
+    unittest.main()
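How the test exercises the handler: the pserver runs in a child spawned with multiprocessing.Process, `_raise_signal` finds the child's PID via `ps --ppid` and signals it after a five-second warm-up, and a graceful shutdown ends in `exit(EXIT_SUCCESS)` inside `SignalHandler::StopAndExit`. If the shutdown path ever hangs instead, the 20-second TIMEOUT added to CMakeLists.txt above is what fails the test.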
